## Baseline submission

A notebook that generates a valid submission. It implements three local feature/matcher methods: LoFTR, DISK, and KeyNetAffNetHardNet, and also sets up a SuperPoint + SuperGlue matcher with test-time augmentation.

Remember to enable a GPU accelerator and disable internet access, then press "submit" on the right pane.

In [2]:
# General utilities
import os
from tqdm import tqdm
from time import time
from fastprogress import progress_bar
import gc
import numpy as np
import h5py
from IPython.display import clear_output
from collections import defaultdict
from copy import deepcopy

# CV/ML
import cv2
import torch
import torch.nn.functional as F
import kornia as K
import kornia.feature as KF
from PIL import Image
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# 3D reconstruction
# import pvsac
import pycolmap

In [4]:
# Global debug switch. It is not referenced in the cells visible here;
# presumably later cells read it — confirm before relying on it.
DEBUG = False

In [6]:
print('Kornia version', K.__version__)
print('Pycolmap version', pycolmap.__version__)

# Local feature/matcher to use. Can be LoFTR, KeyNetAffNetHardNet, or DISK.
LOCAL_FEATURE = 'LoFTR'

# Prefer the GPU, but fall back to CPU (with a warning) instead of failing
# later with a CUDA error when no accelerator is attached.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    print('CUDA is not available; falling back to CPU (this will be slow).')
    device = torch.device('cpu')

Kornia version 0.8.1
Pycolmap version 3.11.1


In [7]:
def arr_to_str(a):
    """Flatten array ``a`` in C order and join its elements with ';'."""
    flat = a.reshape(-1)
    return ';'.join(map(str, flat))


def load_torch_image(fname, device=torch.device('cpu')):
    """Load an image file as a float RGB tensor with values in [0, 1].

    Args:
        fname: path to the image file.
        device: torch device the returned tensor is placed on.

    Returns:
        Batched image tensor (batch dimension added by ``image_to_tensor``
        with ``keepdim=False``) in RGB channel order.

    Raises:
        IOError: if OpenCV cannot read ``fname``.
    """
    bgr = cv2.imread(fname)
    # cv2.imread signals failure by returning None rather than raising;
    # without this check the error surfaces later as an obscure kornia error.
    if bgr is None:
        raise IOError(f'Could not read image {fname}')
    img = K.image_to_tensor(bgr, False).float() / 255.
    return K.color.bgr_to_rgb(img.to(device))

In [16]:
# We will use ViT global descriptor to get matching shortlists.

def get_global_desc(fnames, model,
                    device=torch.device('cpu')):
    """Compute L2-normalised global image descriptors with flip TTA.

    Each image is resized to 512x512 and passed through
    ``model.forward_features`` twice: once as-is and once flipped
    left-right. Each feature map is averaged over dims (-1, 2). The two
    views are averaged and the result is L2-normalised.

    Args:
        fnames: image file paths.
        model: timm model exposing ``forward_features``. It is switched to
            eval mode and moved to ``device``.
        device: torch device to run inference on.

    Returns:
        CPU tensor with one descriptor row per image, in input order.

    Raises:
        IOError: if OpenCV cannot read one of the images.
    """
    model = model.eval().to(device)
    config = resolve_data_config({}, model=model)
    transform = create_transform(**config)

    descs = []
    for img_fname_full in tqdm(fnames, total=len(fnames)):
        img = cv2.imread(img_fname_full)
        if img is None:
            raise IOError(f'Could not read image {img_fname_full}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (512, 512))
        timg = transform(Image.fromarray(img)).unsqueeze(0).to(device)
        with torch.no_grad():
            # Pool over dims (-1, 2); assumes forward_features returns
            # (B, C, H, W), i.e. spatial average pooling — confirm per model.
            desc = model.forward_features(timg).mean(dim=(-1, 2)).view(1, -1)
            desc_lr = model.forward_features(timg.flip(-1)).mean(dim=(-1, 2)).view(1, -1)
            desc_norm = F.normalize((desc + desc_lr) / 2, dim=1, p=2)
        descs.append(desc_norm.detach().cpu())
    return torch.cat(descs, dim=0)


def get_img_pairs_exhaustive(img_fnames):
    """Return every unordered index pair (i, j) with i < j over ``img_fnames``.

    Pairs are ordered by i, then by j.
    """
    n = len(img_fnames)
    return [(first, second)
            for first in range(n)
            for second in range(first + 1, n)]


def get_image_pairs_shortlist(fnames,
                              sim_th=0.6,  # should be strict
                              min_pairs=20,
                              exhaustive_if_less=20,
                              device=torch.device('cpu'),
                              model_names=('tf_efficientnet_b7',),
                              model_paths=('/kaggle/input/tf-efficientnet-b7/tf_efficientnet_b7_ra-6c08e654.pth',)):
    """Select image index pairs to match, using global-descriptor similarity.

    Small sets (<= ``exhaustive_if_less`` images) are matched exhaustively.
    Otherwise, global descriptors from each model in ``model_names`` are
    concatenated. Every image is paired with all images whose descriptor
    distance is <= ``sim_th``. If that yields fewer than ``min_pairs``
    candidates, the ``min_pairs`` nearest images are used instead.

    Args:
        fnames: image file paths.
        sim_th: maximum L2 descriptor distance for a pair to be kept.
        min_pairs: minimum number of nearest neighbours considered per image.
        exhaustive_if_less: at or below this many images, return all pairs.
        device: torch device for descriptor extraction.
        model_names: timm model names used for global descriptors.
        model_paths: checkpoint paths, one per entry of ``model_names``.

    Returns:
        Sorted list of unique (i, j) index tuples with i < j.

    Raises:
        ValueError: if ``model_names`` and ``model_paths`` differ in length.
    """
    num_imgs = len(fnames)

    if num_imgs <= exhaustive_if_less:
        return get_img_pairs_exhaustive(fnames)

    if len(model_names) != len(model_paths):
        raise ValueError('model_names and model_paths must have the same length')

    descs_list = []
    for name, path in zip(model_names, model_paths):
        model = timm.create_model(name, checkpoint_path=path)
        model.eval()
        descs_list.append(get_global_desc(fnames, model, device=device))

    descs = torch.cat(descs_list, dim=-1)
    print(descs.shape)
    dm = torch.cdist(descs, descs, p=2).detach().cpu().numpy()
    mask = dm <= sim_th
    ar = np.arange(num_imgs)

    matching_set = set()
    # Visit every image, including the last one. Otherwise the last image
    # never gets its min_pairs nearest-neighbour fallback.
    for st_idx in range(num_imgs):
        to_match = ar[mask[st_idx]]
        if len(to_match) < min_pairs:
            to_match = np.argsort(dm[st_idx])[:min_pairs]
        for idx in to_match:
            if st_idx == idx:
                continue
            # Sanity filter: also rejects NaN / inf distances.
            if dm[st_idx, idx] < 1000:
                matching_set.add(tuple(sorted((st_idx, idx.item()))))
    return sorted(matching_set)

In [17]:
# Code to manipulate a colmap database.
# Forked from https://github.com/colmap/colmap/blob/dev/scripts/python/database.py

# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#
#     * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
#       its contributors may be used to endorse or promote products derived
#       from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)

# This script is based on an original implementation by True Price.

import sys
import sqlite3
import numpy as np


# True when the interpreter is Python 3.
IS_PYTHON3 = sys.version_info[0] >= 3

# Exclusive upper bound on image ids (enforced by the CHECK constraint on the
# images table). image_ids_to_pair_id also uses it as the multiplier when
# packing two image ids into one pair_id.
MAX_IMAGE_ID = 2**31 - 1

# Schema of the `cameras` table; camera intrinsics are stored in `params`
# as a BLOB.
CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras (
    camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    model INTEGER NOT NULL,
    width INTEGER NOT NULL,
    height INTEGER NOT NULL,
    params BLOB,
    prior_focal_length INTEGER NOT NULL)"""

# Schema of the `descriptors` table: one (rows x cols) BLOB per image.
CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors (
    image_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB,
    FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""

# Schema of the `images` table: unique name, owning camera and the
# prior_q* / prior_t* pose-prior columns. Image ids are range-checked
# against MAX_IMAGE_ID.
CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images (
    image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    name TEXT NOT NULL UNIQUE,
    camera_id INTEGER NOT NULL,
    prior_qw REAL,
    prior_qx REAL,
    prior_qy REAL,
    prior_qz REAL,
    prior_tx REAL,
    prior_ty REAL,
    prior_tz REAL,
    CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}),
    FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))
""".format(MAX_IMAGE_ID)

# Schema of the `two_view_geometries` table: per image pair (pair_id), a
# (rows x cols) match BLOB, an integer `config`, and F / E / H matrix BLOBs.
CREATE_TWO_VIEW_GEOMETRIES_TABLE = """
CREATE TABLE IF NOT EXISTS two_view_geometries (
    pair_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB,
    config INTEGER NOT NULL,
    F BLOB,
    E BLOB,
    H BLOB)
"""

# Schema of the `keypoints` table: one (rows x cols) BLOB per image.
CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints (
    image_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB,
    FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)
"""

# Schema of the `matches` table: one (rows x cols) match BLOB per image pair.
CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches (
    pair_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB)"""

# Unique index on image names.
CREATE_NAME_INDEX = \
    "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)"

# All statements above joined into one script, so the whole schema can be
# created with a single executescript() call.
CREATE_ALL = "; ".join([
    CREATE_CAMERAS_TABLE,
    CREATE_IMAGES_TABLE,
    CREATE_KEYPOINTS_TABLE,
    CREATE_DESCRIPTORS_TABLE,
    CREATE_MATCHES_TABLE,
    CREATE_TWO_VIEW_GEOMETRIES_TABLE,
    CREATE_NAME_INDEX
])


def image_ids_to_pair_id(image_id1, image_id2):
    """Pack an unordered pair of image ids into a single integer pair_id.

    The smaller id always goes first, so (a, b) and (b, a) give the same id.
    """
    low, high = min(image_id1, image_id2), max(image_id1, image_id2)
    return low * MAX_IMAGE_ID + high


def pair_id_to_image_ids(pair_id):
    """Unpack a pair_id into (image_id1, image_id2); inverse of image_ids_to_pair_id.

    Uses integer division so both ids come back as integers. The original
    true division (``/``) returned a float for image_id1 in Python 3.

    Returns:
        ``(pair_id // MAX_IMAGE_ID, pair_id % MAX_IMAGE_ID)``.
    """
    image_id1, image_id2 = divmod(pair_id, MAX_IMAGE_ID)
    return image_id1, image_id2


def array_to_blob(array):
    """Serialise a numpy array's raw data (C order) to ``bytes`` for a SQLite BLOB.

    ``ndarray.tobytes`` replaces ``ndarray.tostring``, which has been
    deprecated since NumPy 1.19 and removed in later releases. The old
    Python 2 branch called ``np.getbuffer``, which no longer exists. This
    file already requires Python 3 (it uses f-strings).
    """
    return array.tobytes()


def blob_to_array(blob, dtype, shape=(-1,)):
    """Deserialise a raw BLOB into a numpy array of ``dtype`` reshaped to ``shape``.

    ``np.frombuffer`` replaces ``np.fromstring``, whose binary mode is
    deprecated. frombuffer returns a read-only view of ``blob``, so the
    result is copied; callers still get a writable array, as fromstring
    returned.
    """
    return np.frombuffer(blob, dtype=dtype).reshape(*shape).copy()


class COLMAPDatabase(sqlite3.Connection):
    """sqlite3 connection with helpers to create and fill a COLMAP database.

    Open with ``COLMAPDatabase.connect(path)``, call ``create_tables()`` to
    create the schema, then use the ``add_*`` methods to insert rows. Numpy
    arrays are stored as raw BLOBs via ``array_to_blob``.
    """

    @staticmethod
    def connect(database_path):
        """Open ``database_path`` and return it as a COLMAPDatabase connection."""
        return sqlite3.connect(database_path, factory=COLMAPDatabase)


    def __init__(self, *args, **kwargs):
        super(COLMAPDatabase, self).__init__(*args, **kwargs)

        # Schema helpers: each runs the matching CREATE ... IF NOT EXISTS script.
        self.create_tables = lambda: self.executescript(CREATE_ALL)
        self.create_cameras_table = \
            lambda: self.executescript(CREATE_CAMERAS_TABLE)
        self.create_descriptors_table = \
            lambda: self.executescript(CREATE_DESCRIPTORS_TABLE)
        self.create_images_table = \
            lambda: self.executescript(CREATE_IMAGES_TABLE)
        self.create_two_view_geometries_table = \
            lambda: self.executescript(CREATE_TWO_VIEW_GEOMETRIES_TABLE)
        self.create_keypoints_table = \
            lambda: self.executescript(CREATE_KEYPOINTS_TABLE)
        self.create_matches_table = \
            lambda: self.executescript(CREATE_MATCHES_TABLE)
        self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX)

    def add_camera(self, model, width, height, params,
                   prior_focal_length=False, camera_id=None):
        """Insert a camera row and return its camera_id.

        ``params`` is stored as a float64 BLOB. With ``camera_id=None`` the
        AUTOINCREMENT primary key assigns the id.
        """
        params = np.asarray(params, np.float64)
        cursor = self.execute(
            "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)",
            (camera_id, model, width, height, array_to_blob(params),
             prior_focal_length))
        return cursor.lastrowid

    def add_image(self, name, camera_id,
                  prior_q=None, prior_t=None, image_id=None):
        """Insert an image row and return its image_id.

        ``prior_q`` (4 values) and ``prior_t`` (3 values) fill the prior_q* /
        prior_t* columns; both default to zeros. None sentinels avoid sharing
        one numpy default array across calls.
        """
        if prior_q is None:
            prior_q = np.zeros(4)
        if prior_t is None:
            prior_t = np.zeros(3)
        cursor = self.execute(
            "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2],
             prior_q[3], prior_t[0], prior_t[1], prior_t[2]))
        return cursor.lastrowid

    def add_keypoints(self, image_id, keypoints):
        """Store an (N, 2|4|6) keypoint array for ``image_id`` as float32."""
        # Convert first so array-likes (e.g. nested lists) are accepted too.
        keypoints = np.asarray(keypoints, np.float32)
        assert(len(keypoints.shape) == 2)
        assert(keypoints.shape[1] in [2, 4, 6])

        self.execute(
            "INSERT INTO keypoints VALUES (?, ?, ?, ?)",
            (image_id,) + keypoints.shape + (array_to_blob(keypoints),))

    def add_descriptors(self, image_id, descriptors):
        """Store the descriptors of ``image_id`` as a contiguous uint8 BLOB."""
        descriptors = np.ascontiguousarray(descriptors, np.uint8)
        self.execute(
            "INSERT INTO descriptors VALUES (?, ?, ?, ?)",
            (image_id,) + descriptors.shape + (array_to_blob(descriptors),))

    def add_matches(self, image_id1, image_id2, matches):
        """Store (N, 2) keypoint-index matches between two images.

        Columns are swapped when ``image_id1 > image_id2`` so the stored
        columns follow the ascending id order used by the pair_id.
        """
        matches = np.asarray(matches)
        assert(len(matches.shape) == 2)
        assert(matches.shape[1] == 2)

        if image_id1 > image_id2:
            matches = matches[:,::-1]

        pair_id = image_ids_to_pair_id(image_id1, image_id2)
        matches = np.asarray(matches, np.uint32)
        self.execute(
            "INSERT INTO matches VALUES (?, ?, ?, ?)",
            (pair_id,) + matches.shape + (array_to_blob(matches),))

    def add_two_view_geometry(self, image_id1, image_id2, matches,
                              F=None, E=None, H=None, config=2):
        """Store two-view geometry (matches plus F, E, H) for an image pair.

        F, E and H default to 3x3 identity matrices and are stored as float64
        BLOBs. Match columns are swapped as in add_matches.
        """
        if F is None:
            F = np.eye(3)
        if E is None:
            E = np.eye(3)
        if H is None:
            H = np.eye(3)

        matches = np.asarray(matches)
        assert(len(matches.shape) == 2)
        assert(matches.shape[1] == 2)

        if image_id1 > image_id2:
            matches = matches[:,::-1]

        pair_id = image_ids_to_pair_id(image_id1, image_id2)
        matches = np.asarray(matches, np.uint32)
        F = np.asarray(F, dtype=np.float64)
        E = np.asarray(E, dtype=np.float64)
        H = np.asarray(H, dtype=np.float64)
        self.execute(
            "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
            (pair_id,) + matches.shape + (array_to_blob(matches), config,
             array_to_blob(F), array_to_blob(E), array_to_blob(H)))

In [19]:
# Code to interface DISK with Colmap.
# Forked from https://github.com/cvlab-epfl/disk/blob/37f1f7e971cea3055bb5ccfc4cf28bfd643fa339/colmap/h5_to_db.py

#  Copyright [2020] [Michał Tyszkiewicz, Pascal Fua, Eduard Trulls]
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

import os, argparse, h5py, warnings
import numpy as np
from tqdm import tqdm
from PIL import Image, ExifTags


def get_focal(image_path, err_on_default=False):
    """Estimate the focal length of ``image_path`` in pixels.

    Follows COLMAP's bitmap.cc logic: if the EXIF ``FocalLengthIn35mmFilm``
    tag is present and positive, focal = focal_35mm / 35 * max(width, height).
    Otherwise the prior 1.2 * max(width, height) is used.

    Args:
        image_path: path of the image file.
        err_on_default: if True, raise instead of falling back to the prior.

    Returns:
        Focal length in pixels (float).

    Raises:
        RuntimeError: if no usable EXIF focal length is found and
            ``err_on_default`` is True.
    """
    focal_35mm = None
    # `with` closes the file handle that Image.open keeps open.
    with Image.open(image_path) as image:
        max_size = max(image.size)
        exif = image.getexif()
        if exif is not None:
            # https://github.com/colmap/colmap/blob/d3a29e203ab69e91eda938d6e56e1c7339d62a99/src/util/bitmap.cc#L299
            # FocalLengthIn35mmFilm normally lives in the Exif sub-IFD
            # (pointer tag 0x8769). Pillow's getexif() does not merge that
            # IFD into the top level, so search both.
            ifds = [exif]
            if hasattr(exif, 'get_ifd'):
                ifds.append(exif.get_ifd(0x8769))
            for ifd in ifds:
                for tag, value in ifd.items():
                    if ExifTags.TAGS.get(tag, None) == 'FocalLengthIn35mmFilm':
                        focal_35mm = float(value)
                        break
                if focal_35mm is not None:
                    break

    focal = None
    # EXIF uses 0 for "unknown", which would otherwise yield focal = 0.
    if focal_35mm is not None and focal_35mm > 0:
        focal = focal_35mm / 35. * max_size

    if focal is None:
        if err_on_default:
            raise RuntimeError("Failed to find focal length")

        # failed to find it in exif, use prior
        FOCAL_PRIOR = 1.2
        focal = FOCAL_PRIOR * max_size

    return focal

def create_camera(db, image_path, camera_model):
    """Register a camera for ``image_path`` in ``db`` and return its camera_id.

    Intrinsics start from get_focal's estimate, and the principal point is
    the image centre.

    Args:
        db: database object exposing ``add_camera`` (a COLMAPDatabase).
        image_path: path of the image the camera belongs to.
        camera_model: one of 'simple-pinhole', 'pinhole', 'simple-radial',
            'opencv'.

    Returns:
        The camera_id returned by ``db.add_camera``.

    Raises:
        ValueError: if ``camera_model`` is not one of the supported names.
    """
    with Image.open(image_path) as image:
        width, height = image.size

    focal = get_focal(image_path)
    cx, cy = width / 2, height / 2

    # Model ids: 0 simple pinhole, 1 pinhole, 2 simple radial, 4 opencv.
    if camera_model == 'simple-pinhole':
        model, params = 0, [focal, cx, cy]
    elif camera_model == 'pinhole':
        model, params = 1, [focal, focal, cx, cy]
    elif camera_model == 'simple-radial':
        model, params = 2, [focal, cx, cy, 0.1]
    elif camera_model == 'opencv':
        model, params = 4, [focal, focal, cx, cy, 0., 0., 0., 0.]
    else:
        # Previously an unknown name fell through to an UnboundLocalError.
        raise ValueError(f'Unsupported camera model {camera_model!r}')

    return db.add_camera(model, width, height, np.array(params))


def add_keypoints(db, h5_path, image_path, img_ext, camera_model, single_camera = True):
    """Insert cameras, images and keypoints from ``<h5_path>/keypoints.h5`` into ``db``.

    Each top-level key of keypoints.h5 is used directly as the image file
    name, relative to ``image_path``.

    Args:
        db: COLMAPDatabase to insert into.
        h5_path: directory containing keypoints.h5.
        image_path: directory containing the image files.
        img_ext: unused; the HDF5 keys are used as file names as-is.
        camera_model: camera model name passed to create_camera.
        single_camera: if True, one camera is shared by all images;
            otherwise each image gets its own camera.

    Returns:
        dict mapping each HDF5 key (file name) to its new image_id.

    Raises:
        IOError: if an image file named by a key does not exist.
    """
    camera_id = None
    fname_to_id = {}
    # `with` guarantees the HDF5 file is closed, even if an error is raised.
    with h5py.File(os.path.join(h5_path, 'keypoints.h5'), 'r') as keypoint_f:
        for filename in tqdm(list(keypoint_f.keys())):
            keypoints = keypoint_f[filename][()]

            path = os.path.join(image_path, filename)
            if not os.path.isfile(path):
                raise IOError(f'Invalid image path {path}')

            if camera_id is None or not single_camera:
                camera_id = create_camera(db, path, camera_model)
            image_id = db.add_image(filename, camera_id)
            fname_to_id[filename] = image_id

            db.add_keypoints(image_id, keypoints)

    return fname_to_id

def add_matches(db, h5_path, fname_to_id):
    """Insert matches from ``<h5_path>/matches.h5`` into ``db``.

    matches.h5 is read as groups ``[key_1][key_2]``, each holding a match
    array for that image pair. A pair already inserted (in either order) is
    skipped with a warning.

    Args:
        db: COLMAPDatabase to insert into.
        h5_path: directory containing matches.h5.
        fname_to_id: mapping from HDF5 key (file name) to image_id, as
            returned by add_keypoints.

    Raises:
        KeyError: if a key in matches.h5 is missing from ``fname_to_id``.
    """
    added = set()
    # `with` guarantees the HDF5 file is closed, even if an error is raised.
    with h5py.File(os.path.join(h5_path, 'matches.h5'), 'r') as match_file:
        n_keys = len(match_file.keys())
        # Progress-bar total: number of unordered pairs of top-level keys.
        n_total = (n_keys * (n_keys - 1)) // 2

        with tqdm(total=n_total) as pbar:
            for key_1 in match_file.keys():
                group = match_file[key_1]
                for key_2 in group.keys():
                    id_1 = fname_to_id[key_1]
                    id_2 = fname_to_id[key_2]

                    pair_id = image_ids_to_pair_id(id_1, id_2)
                    if pair_id in added:
                        warnings.warn(f'Pair {pair_id} ({id_1}, {id_2}) already added!')
                        continue

                    matches = group[key_2][()]
                    db.add_matches(id_1, id_2, matches)

                    added.add(pair_id)

                    pbar.update(1)

In [20]:
# Making kornia local features loading w/o internet
class KeyNetAffNetHardNet(KF.LocalFeature):
    """Convenience module, which implements KeyNet detector + AffNet + HardNet descriptor.

    Offline variant: all pretrained weights (OriNet.pth, keynet_pytorch.pth,
    AffNet.pth, HardNetLib.pth) are read from ``weights_dir`` instead of being
    downloaded, so the module can be built with internet access disabled.

    .. image:: _static/img/keynet_affnet.jpg
    """

    def __init__(
        self,
        num_features: int = 5000,
        upright: bool = False,
        device = torch.device('cpu'),
        scale_laf: float = 1.0,
        weights_dir: str = '/kaggle/input/kornia-local-feature-weights',
    ):
        def _load_weights(fname):
            # map_location='cpu' lets checkpoints saved from CUDA tensors load
            # on a CPU-only machine. load_state_dict then copies the values
            # onto whatever device the target module already lives on.
            path = os.path.join(weights_dir, fname)
            return torch.load(path, map_location='cpu')['state_dict']

        # Fixed orientation (PassLAF) when upright, otherwise OriNet-estimated.
        ori_module = KF.PassLAF() if upright else KF.LAFOrienter(angle_detector=KF.OriNet(False)).eval()
        if not upright:
            ori_module.angle_detector.load_state_dict(_load_weights('OriNet.pth'))
        detector = KF.KeyNetDetector(
            False, num_features=num_features, ori_module=ori_module, aff_module=KF.LAFAffNetShapeEstimator(False).eval()
        ).to(device)
        detector.model.load_state_dict(_load_weights('keynet_pytorch.pth'))
        detector.aff.load_state_dict(_load_weights('AffNet.pth'))

        hardnet = KF.HardNet(False).eval()
        hardnet.load_state_dict(_load_weights('HardNetLib.pth'))
        descriptor = KF.LAFDescriptor(hardnet, patch_size=32, grayscale_descriptor=True).to(device)
        super().__init__(detector, descriptor, scale_laf)

In [21]:
import os
import numpy as np
import cv2
import csv
from glob import glob
import torch
import matplotlib.pyplot as plt
import gc


import torch
# Tracks which optional components the install cells below have set up in
# this kernel (keyed by component name).
INSTALLED_LOG = {}

# Warn early: the notebook is meant to run with a GPU accelerator enabled.
has_gpu = torch.cuda.is_available()
if has_gpu is not True:
    print('You may want to enable the GPU switch?')
del has_gpu

You may want to enable the GPU switch?


In [None]:
# Install superglue
force_superglue_reinstall = False

if 'superglue' not in INSTALLED_LOG or force_superglue_reinstall:
    !mkdir /tmp/superpoint
    !cp -r ../input/super-glue-pretrained-network/models /tmp/superpoint/superpoint
    !ls /tmp/superpoint/superpoint
    !touch /tmp/superpoint/superpoint/__init__.py
    INSTALLED_LOG['superglue'] = True
else:
    print('Already installed SuperGlue. Set "force_superglue_reinstall=True" to override this behavior.')

    # https://www.kaggle.com/datasets/losveria/super-glue-pretrained-network !!! SOURCE OF SUPERPOINT AND SUPERGLUE

# Import superglue
import sys
sys.path.append("/tmp/superpoint")
from superpoint.superpoint import SuperPoint
from superpoint.superglue import SuperGlue


class SuperGlueCustomMatchingV2(torch.nn.Module):
    """ Image Matching Frontend (SuperPoint + SuperGlue) with test-time augmentation.

    Keypoints detected on each augmented view are mapped back to the
    un-augmented image frame by the function registered for that TTA name in
    ``self.tta_map``. The batch presumably holds one augmented view per entry
    of ``ttas`` — confirm against the caller.
    """
    def __init__(self, config=None, device=None):
        super().__init__()
        # None sentinel instead of a shared mutable {} default.
        config = {} if config is None else config
        self.superpoint = SuperPoint(config.get('superpoint', {}))
        self.superglue = SuperGlue(config.get('superglue', {}))

        # TTA name -> function that undoes that augmentation on keypoints.
        self.tta_map = {
            'orig': self.untta_none,
            'eqhist': self.untta_none,
            'clahe': self.untta_none,
            'flip_lr': self.untta_fliplr,
            'flip_ud': self.untta_flipud,
            'rot_r10': self.untta_rotr10,
            'rot_l10': self.untta_rotl10,
            'fliplr_rotr10': self.untta_fliplr_rotr10,
            'fliplr_rotl10': self.untta_fliplr_rotl10
        }
        self.device = device

    def forward_flat(self, data, ttas=None, tta_groups=None):
        """ Run SuperPoint (optionally) and SuperGlue, pooling keypoints per TTA group.

        Keypoints are un-TTA'd (mapped back to the original frame, dropping
        out-of-image points) BEFORE matching. All keypoints of the views in a
        group are then concatenated and matched in one SuperGlue call.
        SuperPoint is skipped if ['keypoints0', 'keypoints1'] exist in input.
        NOTE(review): in that case pred has no 'scores0'/'scores1' and the
        un-TTA step below raises KeyError — confirm that path is unused.

        Args:
          data: dictionary with minimal keys: ['image0', 'image1']
          ttas: TTA name per batch entry (default ['orig']).
          tta_groups: lists of TTA names to pool together (default [['orig']]).

        Returns:
          One dict per group with the group inputs and SuperGlue outputs.
        """
        if ttas is None:
            ttas = ['orig']
        if tta_groups is None:
            tta_groups = [['orig']]
        pred = {}

        # Extract SuperPoint (keypoints, scores, descriptors) if not provided
        if 'keypoints0' not in data:
            pred0 = self.superpoint({'image': data['image0']})
            pred = {**pred, **{k+'0': v for k, v in pred0.items()}}
        if 'keypoints1' not in data:
            pred1 = self.superpoint({'image': data['image1']})
            pred = {**pred, **{k+'1': v for k, v in pred1.items()}}

        # Reverse-tta before inference
        pred['scores0'] = list(pred['scores0'])
        pred['scores1'] = list(pred['scores1'])
        for i in range(len(pred['keypoints0'])):
            pred['keypoints0'][i], pred['descriptors0'][i], pred['scores0'][i] = self.tta_map[ttas[i]](
                pred['keypoints0'][i], pred['descriptors0'][i], pred['scores0'][i],
                w=data['image0'].shape[3], h=data['image0'].shape[2], inplace=True, mask_illegal=True)

            pred['keypoints1'][i], pred['descriptors1'][i], pred['scores1'][i] = self.tta_map[ttas[i]](
                pred['keypoints1'][i], pred['descriptors1'][i], pred['scores1'][i],
                w=data['image1'].shape[3], h=data['image1'].shape[2], inplace=True, mask_illegal=True)

        # Batch all features
        # We should either have i) one image per batch, or
        # ii) the same number of local features for all images in the batch.
        data = {**data, **pred}

        group_preds = []
        for tta_group in tta_groups:
            # Builtin bool: the np.bool alias was removed in NumPy 1.24.
            group_mask = torch.from_numpy(np.array([x in tta_group for x in ttas], dtype=bool))
            group_data = {
                **{f'keypoints{k}': [data[f'keypoints{k}'][i] for i in range(len(ttas)) if ttas[i] in tta_group] for k in [0, 1]},
                **{f'descriptors{k}': [data[f'descriptors{k}'][i] for i in range(len(ttas)) if ttas[i] in tta_group] for k in [0, 1]},
                **{f'scores{k}': [data[f'scores{k}'][i] for i in range(len(ttas)) if ttas[i] in tta_group] for k in [0, 1]},
                **{f'image{k}': data[f'image{k}'][group_mask, ...] for k in [0, 1]},
            }
            for k, v in group_data.items():
                if isinstance(group_data[k], (list, tuple)):
                    if k.startswith('descriptor'):
                        # Descriptors are concatenated along dim 1 (their
                        # keypoint axis, per descriptors[:, mask] below).
                        group_data[k] = torch.cat(group_data[k], 1)[None, ...]
                    else:
                        group_data[k] = torch.cat(group_data[k])[None, ...]
                else:
                    group_data[k] = torch.flatten(group_data[k], 0, 1)[None, ...]
            group_pred = {
                **group_data,
                **self.superglue(group_data)
            }
            group_preds.append(group_pred)
        return group_preds

    def forward_cross(self, data, ttas=None, tta_groups=None):
        """Match pairs of augmented views, then un-TTA the keypoints.

        Unlike forward_flat, each (view0, view1) pair in ``tta_groups`` is
        matched separately with SuperGlue on the augmented keypoints. The
        keypoints are mapped back to the original frame afterwards, without
        masking, so match indices stay valid.

        Args:
          data: dictionary with minimal keys: ['image0', 'image1']
          ttas: TTA name per batch entry (default ['orig']).
          tta_groups: (tta_name0, tta_name1) pairs (default [('orig', 'orig')]).

        Returns:
          One SuperGlue prediction dict per group, updated with the un-TTA'd
          keypoints and scores of that group.
        """
        if ttas is None:
            ttas = ['orig']
        if tta_groups is None:
            tta_groups = [('orig', 'orig')]
        pred = {}

        # Extract SuperPoint (keypoints, scores, descriptors) if not provided
        if 'keypoints0' not in data:
            pred0 = self.superpoint({'image': data['image0']})
            pred = {**pred, **{k+'0': v for k, v in pred0.items()}}
        if 'keypoints1' not in data:
            pred1 = self.superpoint({'image': data['image1']})
            pred = {**pred, **{k+'1': v for k, v in pred1.items()}}

        # Batch all features
        # We should either have i) one image per batch, or
        # ii) the same number of local features for all images in the batch.
        data = {**data, **pred}

        # Group predictions (list, with elements with matches{0,1}, matching_scores{0,1} keys)
        group_pred_list = []
        tta2id = {k: i for i, k in enumerate(ttas)}
        for tta_group in tta_groups:
            group_idx = tta2id[tta_group[0]], tta2id[tta_group[1]]
            group_data = {
                **{f'image{i}': data[f'image{i}'][group_idx[i]:group_idx[i]+1] for i in [0, 1]},
                **{f'keypoints{i}': data[f'keypoints{i}'][group_idx[i]:group_idx[i]+1] for i in [0, 1]},
                **{f'descriptors{i}': data[f'descriptors{i}'][group_idx[i]:group_idx[i]+1] for i in [0, 1]},
                **{f'scores{i}': data[f'scores{i}'][group_idx[i]:group_idx[i]+1] for i in [0, 1]},
            }

            for k in group_data:
                if isinstance(group_data[k], (list, tuple)):
                    group_data[k] = torch.stack(group_data[k])

            group_sg_pred = self.superglue(group_data)
            group_pred_list.append(group_sg_pred)

        # UnTTA
        data['scores0'] = list(data['scores0'])
        data['scores1'] = list(data['scores1'])
        for i in range(len(data['keypoints0'])):
            data['keypoints0'][i], data['descriptors0'][i], data['scores0'][i] = self.tta_map[ttas[i]](
                data['keypoints0'][i], data['descriptors0'][i], data['scores0'][i],
                w=data['image0'].shape[3], h=data['image0'].shape[2], inplace=True, mask_illegal=False)

            data['keypoints1'][i], data['descriptors1'][i], data['scores1'][i] = self.tta_map[ttas[i]](
                data['keypoints1'][i], data['descriptors1'][i], data['scores1'][i],
                w=data['image1'].shape[3], h=data['image1'].shape[2], inplace=True, mask_illegal=False)

        # Attach the un-TTA'd keypoints/scores to each group's prediction.
        for group_pred, tta_group in zip(group_pred_list, tta_groups):
            group_idx = tta2id[tta_group[0]], tta2id[tta_group[1]]
            group_pred.update({
                **{f'keypoints{i}': data[f'keypoints{i}'][group_idx[i]:group_idx[i]+1] for i in [0, 1]},
                **{f'scores{i}': data[f'scores{i}'][group_idx[i]:group_idx[i]+1] for i in [0, 1]},
            })
        return group_pred_list


    def untta_none(self, keypoints, descriptors, scores, w, h, inplace=True, mask_illegal=True):
        """Identity un-TTA; clones keypoints when ``inplace`` is False."""
        if not inplace:
            keypoints = keypoints.clone()
        return keypoints, descriptors, scores

    def untta_fliplr(self, keypoints, descriptors, scores, w, h, inplace=True, mask_illegal=True):
        """Undo a left-right flip: x -> w - x - 1."""
        if not inplace:
            keypoints = keypoints.clone()
        keypoints[:, 0] = w - keypoints[:, 0] - 1.
        return keypoints, descriptors, scores

    def untta_flipud(self, keypoints, descriptors, scores, w, h, inplace=True, mask_illegal=True):
        """Undo an up-down flip: y -> h - y - 1."""
        if not inplace:
            keypoints = keypoints.clone()
        keypoints[:, 1] = h - keypoints[:, 1] - 1.
        return keypoints, descriptors, scores

    def _untta_rotation(self, keypoints, descriptors, scores, w, h, angle, flip_lr, mask_illegal):
        """Rotate keypoints by ``angle`` degrees about the image centre (cv2 convention).

        Optionally mirrors x afterwards (x -> w - x - 1). When ``mask_illegal``
        is set, points falling outside the w x h image are dropped together
        with their descriptors (columns) and scores. Always returns new
        keypoint tensors; the input keypoints are not modified.
        """
        rot_M_inv = torch.from_numpy(cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)).to(torch.float32).to(self.device)
        ones = torch.ones_like(keypoints[:, 0])
        hom = torch.cat([keypoints, ones[:, None]], 1)
        rot_kpts = torch.matmul(rot_M_inv, hom.T).T[:, :2]
        if flip_lr:
            rot_kpts[:, 0] = w - rot_kpts[:, 0] - 1.
        if mask_illegal:
            mask = (rot_kpts[:, 0] >= 0) & (rot_kpts[:, 0] < w) & (rot_kpts[:, 1] >= 0) & (rot_kpts[:, 1] < h)
            return rot_kpts[mask], descriptors[:, mask], scores[mask]
        return rot_kpts, descriptors, scores

    # NOTE(review): the rotation un-TTAs below use +/-15 degrees although the
    # TTA names say "10". The angle must match the one used to rotate the
    # input images — confirm against the preprocessing code. ``inplace`` is
    # ignored for rotations (a new tensor is always returned).

    def untta_rotr10(self, keypoints, descriptors, scores, w, h, inplace=True, mask_illegal=True):
        """Undo the 'rot_r10' augmentation (applies a -15 degree rotation)."""
        return self._untta_rotation(keypoints, descriptors, scores, w, h, -15, False, mask_illegal)

    def untta_rotl10(self, keypoints, descriptors, scores, w, h, inplace=True, mask_illegal=True):
        """Undo the 'rot_l10' augmentation (applies a +15 degree rotation)."""
        return self._untta_rotation(keypoints, descriptors, scores, w, h, 15, False, mask_illegal)

    def untta_fliplr_rotr10(self, keypoints, descriptors, scores, w, h, inplace=True, mask_illegal=True):
        """Undo 'fliplr_rotr10': -15 degree rotation, then left-right flip."""
        return self._untta_rotation(keypoints, descriptors, scores, w, h, -15, True, mask_illegal)

    def untta_fliplr_rotl10(self, keypoints, descriptors, scores, w, h, inplace=True, mask_illegal=True):
        """Undo 'fliplr_rotl10': +15 degree rotation, then left-right flip."""
        return self._untta_rotation(keypoints, descriptors, scores, w, h, 15, True, mask_illegal)


class SuperGlueMatcherV2:
    """SuperPoint + SuperGlue pair matcher with test-time augmentation (TTA).

    Wraps ``SuperGlueCustomMatchingV2`` (defined elsewhere in the notebook).
    Both images are resized so their long side equals ``input_longside`` and
    converted to grayscale. Every TTA variant is then stacked into one batch and
    matched in a single forward pass. Matches are returned in original-image
    pixel coordinates.
    """

    # Rotation in degrees used by the rotation TTAs, despite their "10" names.
    # NOTE(review): the untta_* helpers elsewhere in the notebook also use 15;
    # presumably both must agree - confirm before changing.
    TTA_ROT_ANGLE = 15
    # Rotation TTA name -> (flip left-right before rotating?, sign of the rotation).
    _ROTATION_TTAS = {
        'rot_r10': (False, 1),
        'rot_l10': (False, -1),
        'fliplr_rotr10': (True, 1),
        'fliplr_rotl10': (True, -1),
    }

    def __init__(self, config, device=None, conf_th=None):
        self.config = config
        self.device = device
        self._superglue_matcher = SuperGlueCustomMatchingV2(
            config=config, device=self.device,
        ).eval().to(device)
        # Optional extra cut on SuperGlue matching scores (None keeps every match).
        self.conf_thresh = conf_th

    def prep_np_img(self, img, long_side=None):
        """Resize a BGR image so its long side is ``long_side`` and convert it to grayscale.

        Returns ``(gray_img, scale)``, where ``scale`` is new size / old size
        (1.0 when ``long_side`` is None and no resizing happens).
        """
        if long_side is not None:
            scale = long_side / max(img.shape[0], img.shape[1])
            w = int(img.shape[1] * scale)
            h = int(img.shape[0] * scale)
            img = cv2.resize(img, (w, h))
        else:
            scale = 1.0
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), scale

    def frame2tensor(self, frame):
        """Turn an (H, W) image into a (1, 1, H, W) float tensor divided by 255, on ``self.device``."""
        return (torch.from_numpy(frame).float()/255.)[None, None].to(self.device)

    def tta_rotation_preprocess(self, img_np, angle):
        """Rotate ``img_np`` by ``angle`` degrees about its centre.

        Returns ``(rot_M, rotated_tensor, rot_M_inv)``. Both matrices are 2x3
        outputs of ``cv2.getRotationMatrix2D``.
        """
        rot_M = cv2.getRotationMatrix2D((img_np.shape[1] / 2, img_np.shape[0] / 2), angle, 1)
        rot_M_inv = cv2.getRotationMatrix2D((img_np.shape[1] / 2, img_np.shape[0] / 2), -angle, 1)
        rot_img = self.frame2tensor(cv2.warpAffine(img_np, rot_M, (img_np.shape[1], img_np.shape[0])))
        return rot_M, rot_img, rot_M_inv

    def tta_rotation_postprocess(self, kpts, img_np, rot_M_inv):
        """Map (x, y) keypoints through the 2x3 matrix ``rot_M_inv``.

        Returns ``(mapped_kpts, inside_mask)``. ``inside_mask`` flags the points
        that land within ``img_np``'s width and height.
        """
        ones = np.ones(shape=(kpts.shape[0], ), dtype=np.float32)[:, None]
        hom = np.concatenate([kpts, ones], 1)
        rot_kpts = rot_M_inv.dot(hom.T).T[:, :2]
        mask = (rot_kpts[:, 0] >= 0) & (rot_kpts[:, 0] < img_np.shape[1]) & (rot_kpts[:, 1] >= 0) & (rot_kpts[:, 1] < img_np.shape[0])
        return rot_kpts, mask

    @staticmethod
    def _collect_ttas(tta_groups):
        """Return the distinct TTA names used in ``tta_groups``, in first-seen order.

        ``dict.fromkeys`` replaces ``set`` so the batch order is reproducible.
        Set iteration order over strings depends on per-process hash
        randomisation.
        """
        return list(dict.fromkeys(name for group in tta_groups for name in group))

    @staticmethod
    def _inside(pts, shape):
        """Return a boolean mask of (x, y) points inside an image whose shape starts with (h, w)."""
        return (pts[:, 0] >= 0) & (pts[:, 0] < shape[1]) & (pts[:, 1] >= 0) & (pts[:, 1] < shape[0])

    def _augment_pair(self, tta_elem, img_np0, img_np1, img_ts0, img_ts1):
        """Build the augmented (image0, image1) tensors for a single TTA name.

        Raises:
            ValueError: if ``tta_elem`` is not a known TTA name.
        """
        if tta_elem == 'orig':
            return img_ts0, img_ts1
        if tta_elem == 'flip_lr':
            return torch.flip(img_ts0, [3, ]), torch.flip(img_ts1, [3, ])
        if tta_elem == 'flip_ud':
            return torch.flip(img_ts0, [2, ]), torch.flip(img_ts1, [2, ])
        if tta_elem in self._ROTATION_TTAS:
            flip_first, sign = self._ROTATION_TTAS[tta_elem]
            src0 = img_np0[:, ::-1] if flip_first else img_np0
            src1 = img_np1[:, ::-1] if flip_first else img_np1
            angle = sign * self.TTA_ROT_ANGLE
            _, aug0, _ = self.tta_rotation_preprocess(src0, angle)
            _, aug1, _ = self.tta_rotation_preprocess(src1, angle)
            return aug0, aug1
        if tta_elem == 'eqhist':
            return self.frame2tensor(cv2.equalizeHist(img_np0)), self.frame2tensor(cv2.equalizeHist(img_np1))
        if tta_elem == 'clahe':
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
            return self.frame2tensor(clahe.apply(img_np0)), self.frame2tensor(clahe.apply(img_np1))
        raise ValueError('Unknown TTA method.')

    def __call__(self, img_np0, img_np1, input_longside, tta_groups=None, forward_type='cross'):
        """Match two BGR images and return corresponding keypoints.

        Args:
            img_np0, img_np1: BGR images (converted with ``cv2.COLOR_BGR2GRAY``).
            input_longside: target long side before matching (None = no resize).
            tta_groups: sequence of TTA-name tuples, default ``[('orig', 'orig')]``.
                The distinct names form the batch, stacked in the order of the
                list that is passed to the model as ``ttas``.
            forward_type: 'cross' or 'flat'; selects the model's forward method.

        Returns:
            (mkpts0, mkpts1): matched (x, y) points in original-image pixels.
            Matches with an endpoint outside either resized image are dropped.

        Raises:
            ValueError: unknown TTA name.
            RuntimeError: unknown ``forward_type``.
        """
        if tta_groups is None:
            tta_groups = [('orig', 'orig')]
        with torch.no_grad():
            img_np0, scale0 = self.prep_np_img(img_np0, input_longside)
            img_np1, scale1 = self.prep_np_img(img_np1, input_longside)
            img_ts0 = self.frame2tensor(img_np0)
            img_ts1 = self.frame2tensor(img_np1)

            tta = self._collect_ttas(tta_groups)
            images0, images1 = [], []
            for tta_elem in tta:
                aug0, aug1 = self._augment_pair(tta_elem, img_np0, img_np1, img_ts0, img_ts1)
                images0.append(aug0)
                images1.append(aug1)

            if forward_type == 'cross':
                forward = self._superglue_matcher.forward_cross
            elif forward_type == 'flat':
                forward = self._superglue_matcher.forward_flat
            else:
                raise RuntimeError(f'Unknown forward_type {forward_type}')
            pred = forward(
                data={
                    "image0": torch.cat(images0),
                    "image1": torch.cat(images1),
                },
                ttas=tta, tta_groups=tta_groups)

            mkpts0, mkpts1 = [], []
            for group_pred in pred:
                pred_aug = {k: v[0].detach().cpu().numpy().squeeze() for k, v in group_pred.items()}
                kpts0, kpts1 = pred_aug["keypoints0"], pred_aug["keypoints1"]
                matches, conf = pred_aug["matches0"], pred_aug["matching_scores0"]
                # matches0 == -1 marks an unmatched keypoint of image 0.
                if self.conf_thresh is None:
                    valid = matches > -1
                else:
                    valid = (matches > -1) & (conf >= self.conf_thresh)
                mkpts0.append(kpts0[valid])
                mkpts1.append(kpts1[matches[valid]])

            cat_mkpts0 = np.concatenate(mkpts0)
            cat_mkpts1 = np.concatenate(mkpts1)
            keep = self._inside(cat_mkpts0, img_np0.shape) & self._inside(cat_mkpts1, img_np1.shape)
            return cat_mkpts0[keep] / scale0, cat_mkpts1[keep] / scale1


# 1600 is the validation size in the paper


302.96s - pydevd: Sending message related to process being replaced timed-out after 5 seconds


mkdir: /tmp/superpoint: File exists


308.41s - pydevd: Sending message related to process being replaced timed-out after 5 seconds


cp: ../input/super-glue-pretrained-network/models: No such file or directory


313.82s - pydevd: Sending message related to process being replaced timed-out after 5 seconds


ls: /tmp/superpoint/superpoint: No such file or directory


319.22s - pydevd: Sending message related to process being replaced timed-out after 5 seconds


touch: /tmp/superpoint/superpoint/__init__.py: No such file or directory


ModuleNotFoundError: No module named 'superpoint'

In [None]:
class LoFTRMatcher:
    """Kornia LoFTR matcher with test-time augmentation (TTA).

    Both images are resized so their long side equals ``input_longside`` and
    converted to grayscale. All TTA variants are matched in one LoFTR batch,
    mapped back to the un-augmented frame, filtered, and returned in
    original-image pixel coordinates.
    """

    # Rotation in degrees (cv2.getRotationMatrix2D convention) for each rotation TTA.
    _ROT_ANGLES = {'rot_r10': 10, 'rot_l10': -10}

    def __init__(self, device=None, input_longside=1200, conf_th=None,
                 weights_path="../input/kornia-loftr/loftr_outdoor.ckpt"):
        """
        Args:
            device: torch device for the model and the input tensors.
            input_longside: NOTE(review): not used; ``__call__`` takes its own
                ``input_longside`` argument.
            conf_th: minimum LoFTR confidence to keep a match (None means 0).
            weights_path: checkpoint file holding a ``'state_dict'`` entry.
        """
        self._loftr_matcher = KF.LoFTR(pretrained=None)
        # map_location='cpu' lets a checkpoint saved from CUDA load on any machine.
        # load_state_dict copies into the (CPU) parameters, and .to(device) follows.
        checkpoint = torch.load(weights_path, map_location='cpu')
        self._loftr_matcher.load_state_dict(checkpoint['state_dict'])
        self._loftr_matcher = self._loftr_matcher.to(device).eval()
        self.device = device
        self.conf_thresh = conf_th

    def prep_img(self, img, long_side=1200):
        """Resize a BGR image to ``long_side`` and build its grayscale (1, 1, H, W) tensor.

        Returns ``(resized_bgr, tensor_on_device, scale)``. ``scale`` is new/old
        size and is 1.0 when ``long_side`` is None.
        """
        if long_side is not None:
            scale = long_side / max(img.shape[0], img.shape[1])
            w = int(img.shape[1] * scale)
            h = int(img.shape[0] * scale)
            img = cv2.resize(img, (w, h))
        else:
            scale = 1.0

        img_ts = K.image_to_tensor(img, False).float() / 255.
        img_ts = K.color.bgr_to_rgb(img_ts)
        img_ts = K.color.rgb_to_grayscale(img_ts)
        return img, img_ts.to(self.device), scale

    def tta_rotation_preprocess(self, img_np, angle):
        """Rotate a BGR image by ``angle`` degrees about its centre.

        Returns ``(rot_M, grayscale_tensor, rot_M_inv)``.
        """
        rot_M = cv2.getRotationMatrix2D((img_np.shape[1] / 2, img_np.shape[0] / 2), angle, 1)
        rot_M_inv = cv2.getRotationMatrix2D((img_np.shape[1] / 2, img_np.shape[0] / 2), -angle, 1)
        rot_img = cv2.warpAffine(img_np, rot_M, (img_np.shape[1], img_np.shape[0]))

        rot_img_ts = K.image_to_tensor(rot_img, False).float() / 255.
        rot_img_ts = K.color.bgr_to_rgb(rot_img_ts)
        rot_img_ts = K.color.rgb_to_grayscale(rot_img_ts)
        return rot_M, rot_img_ts.to(self.device), rot_M_inv

    def tta_rotation_postprocess(self, kpts, img_np, rot_M_inv):
        """Map (x, y) keypoints through the 2x3 matrix ``rot_M_inv``.

        Returns ``(mapped_kpts, inside_mask)``, where ``inside_mask`` flags the
        points inside ``img_np``.
        """
        ones = np.ones(shape=(kpts.shape[0], ), dtype=np.float32)[:, None]
        hom = np.concatenate([kpts, ones], 1)
        rot_kpts = rot_M_inv.dot(hom.T).T[:, :2]
        mask = (rot_kpts[:, 0] >= 0) & (rot_kpts[:, 0] < img_np.shape[1]) & (rot_kpts[:, 1] >= 0) & (rot_kpts[:, 1] < img_np.shape[0])
        return rot_kpts, mask

    def __call__(self, img_np1, img_np2, input_longside, tta=('orig', 'flip_lr')):
        """Match two BGR images and return ``(mkpts0, mkpts1)`` in original-image pixels.

        Args:
            img_np1, img_np2: BGR images.
            input_longside: target long side before matching (None = no resize).
            tta: augmentations to run; any of 'orig', 'flip_lr', 'flip_ud',
                'rot_r10', 'rot_l10'. Batch index i corresponds to ``tta[i]``.

        Raises:
            ValueError: for an unknown TTA name.
        """
        with torch.no_grad():
            img_np1, img_ts0, scale0 = self.prep_img(img_np1, input_longside)
            img_np2, img_ts1, scale1 = self.prep_img(img_np2, input_longside)
            images0, images1 = [], []
            inv_rot = {}  # rotation TTA name -> (inverse matrix for image 0, for image 1)

            for tta_elem in tta:
                if tta_elem == 'orig':
                    img_ts0_aug, img_ts1_aug = img_ts0, img_ts1
                elif tta_elem == 'flip_lr':
                    img_ts0_aug = torch.flip(img_ts0, [3, ])
                    img_ts1_aug = torch.flip(img_ts1, [3, ])
                elif tta_elem == 'flip_ud':
                    img_ts0_aug = torch.flip(img_ts0, [2, ])
                    img_ts1_aug = torch.flip(img_ts1, [2, ])
                elif tta_elem in self._ROT_ANGLES:
                    angle = self._ROT_ANGLES[tta_elem]
                    _, img_ts0_aug, inv0 = self.tta_rotation_preprocess(img_np1, angle)
                    _, img_ts1_aug, inv1 = self.tta_rotation_preprocess(img_np2, angle)
                    inv_rot[tta_elem] = (inv0, inv1)
                else:
                    raise ValueError('Unknown TTA method.')
                images0.append(img_ts0_aug)
                images1.append(img_ts1_aug)

            input_dict = {"image0": torch.cat(images0), "image1": torch.cat(images1)}
            correspondences = self._loftr_matcher(input_dict)
            mkpts0 = correspondences['keypoints0'].cpu().numpy()
            mkpts1 = correspondences['keypoints1'].cpu().numpy()
            batch_id = correspondences['batch_indexes'].cpu().numpy()
            confidence = correspondences['confidence'].cpu().numpy()

            # Matches whose un-rotated endpoints leave either image are dropped.
            # An explicit mask is used rather than a confidence penalty, so the
            # result does not depend on the value of conf_thresh.
            valid = np.ones(confidence.shape[0], dtype=bool)
            for idx, tta_elem in enumerate(tta):
                batch_mask = batch_id == idx
                if tta_elem == 'flip_lr':
                    # NOTE(review): the SuperGlue path mirrors with w - x - 1; this
                    # uses w - x. Confirm which fits LoFTR's coordinate convention.
                    mkpts0[batch_mask, 0] = img_np1.shape[1] - mkpts0[batch_mask, 0]
                    mkpts1[batch_mask, 0] = img_np2.shape[1] - mkpts1[batch_mask, 0]
                elif tta_elem == 'flip_ud':
                    mkpts0[batch_mask, 1] = img_np1.shape[0] - mkpts0[batch_mask, 1]
                    mkpts1[batch_mask, 1] = img_np2.shape[0] - mkpts1[batch_mask, 1]
                elif tta_elem in inv_rot:
                    inv0, inv1 = inv_rot[tta_elem]
                    rot0, inside0 = self.tta_rotation_postprocess(mkpts0[batch_mask], img_np1, inv0)
                    rot1, inside1 = self.tta_rotation_postprocess(mkpts1[batch_mask], img_np2, inv1)
                    mkpts0[batch_mask] = rot0
                    mkpts1[batch_mask] = rot1
                    valid[batch_mask] = inside0 & inside1
                # 'orig' needs no correction; unknown names were rejected above.

            conf_th = 0. if self.conf_thresh is None else self.conf_thresh
            keep = valid & (confidence >= conf_th)
            return mkpts0[keep] / scale0, mkpts1[keep] / scale1

In [None]:
# SuperPoint/SuperGlue settings shared by both SuperGlue matchers.
base_config = {
    "superpoint": {
        "nms_radius": 3,
        "keypoint_threshold": 0.005,
        "max_keypoints": 2048,
    },
    "superglue": {
        "weights": "outdoor",
        "sinkhorn_iterations": 100,
        "match_threshold": 0.2,
    },
}

# Identical to base_config apart from a 4x larger SuperPoint keypoint budget (8192).
f8000_config = deepcopy(base_config)
f8000_config["superpoint"]["max_keypoints"] = 2048 * 4

superglue_matcher = SuperGlueMatcherV2(base_config, device=device, conf_th=0.2)
superglue_matcher_8096 = SuperGlueMatcherV2(f8000_config, device=device, conf_th=0.2)
loftr_matcher = LoFTRMatcher(device=device, conf_th=0.3)

Loaded SuperPoint model
Loaded SuperGlue model ("outdoor" weights)
Loaded SuperPoint model
Loaded SuperGlue model ("outdoor" weights)


In [None]:
def detect_features(img_fnames,
                    num_feats = 2048,
                    upright = False,
                    device=torch.device('cpu'),
                    feature_dir = '.featureout',
                    resize_small_edge_to = 600):
    """Detect KeyNet/AffNet/HardNet features for each image and save them to HDF5.

    Each image is optionally resized with ``K.geometry.resize(...,
    resize_small_edge_to)`` before detection. The resulting LAFs are rescaled
    back to the original resolution. Three files are written in
    ``feature_dir``, all keyed by the image file basename: ``lafs.h5`` (LAFs),
    ``keypoints.h5`` (LAF centres, reshaped to (N, 2)) and ``descriptors.h5``
    (descriptors, reshaped to (N, desc_dim)).
    """
    detector = KeyNetAffNetHardNet(num_feats, upright, device).to(device).eval()
    os.makedirs(feature_dir, exist_ok=True)
    with h5py.File(f'{feature_dir}/lafs.h5', mode='w') as f_laf, \
         h5py.File(f'{feature_dir}/keypoints.h5', mode='w') as f_kp, \
         h5py.File(f'{feature_dir}/descriptors.h5', mode='w') as f_desc:
        for img_path in progress_bar(img_fnames):
            key = img_path.split('/')[-1]
            with torch.inference_mode():
                timg = load_torch_image(img_path, device=device)
                orig_h, orig_w = timg.shape[2:]
                if resize_small_edge_to is not None:
                    work_img = K.geometry.resize(timg, resize_small_edge_to, antialias=True)
                else:
                    work_img = timg
                work_h, work_w = work_img.shape[2:]

                lafs, _responses, descs = detector(K.color.rgb_to_grayscale(work_img))
                # Bring the x-row and y-row of every LAF back to original-image pixels.
                lafs[:, :, 0, :] *= float(orig_w) / float(work_w)
                lafs[:, :, 1, :] *= float(orig_h) / float(work_h)

                desc_dim = descs.shape[-1]
                centers = KF.get_laf_center(lafs).reshape(-1, 2).detach().cpu().numpy()
                flat_descs = descs.reshape(-1, desc_dim).detach().cpu().numpy()
                f_laf[key] = lafs.detach().cpu().numpy()
                f_kp[key] = centers
                f_desc[key] = flat_descs

def get_unique_idxs(A, dim=0):
    """Return, for each unique slice of ``A`` along ``dim``, the index of its first occurrence.

    The indices are ordered like the output of
    ``torch.unique(A, dim=dim, sorted=True)``, so
    ``A.index_select(dim, result)`` reproduces that unique tensor.
    An empty input yields an empty int64 tensor.

    Adapted from https://stackoverflow.com/questions/72001505/how-to-get-unique-elements-and-their-firstly-appeared-indices-of-a-pytorch-tenso
    """
    if A.shape[dim] == 0:
        # With no elements the group-start lookup below would index out of range.
        return torch.empty(0, dtype=torch.long, device=A.device)
    _, inverse, counts = torch.unique(A, dim=dim, sorted=True, return_inverse=True, return_counts=True)
    # A stable sort of the inverse map groups equal slices together while keeping
    # each group's members in their original order.
    _, order = torch.sort(inverse, stable=True)
    # Exclusive prefix sum: position of each group's first member within ``order``.
    group_starts = counts.cumsum(0) - counts
    return order[group_starts]

def match_features(img_fnames,
                   index_pairs,
                   feature_dir = '.featureout',
                   device=torch.device('cpu'),
                   min_matches=15,
                   force_mutual = True,
                   matching_alg='smnn'
                  ):
    """Match stored local features for each image pair and write index matches to HDF5.

    Reads ``lafs.h5`` and ``descriptors.h5`` from ``feature_dir``. Each pair is
    matched with ``KF.match_smnn`` (ratio 0.9) or ``KF.match_adalam``. Results
    go to ``matches.h5`` as group ``key1`` / dataset ``key2`` with rows
    (index in image 1, index in image 2). Pairs with fewer than ``min_matches``
    matches get no dataset, although their ``key1`` group is still created.
    """
    assert matching_alg in ['smnn', 'adalam']

    def _load(h5_file, key):
        """Read one HDF5 dataset into a tensor on ``device``."""
        return torch.from_numpy(h5_file[key][...]).to(device)

    with h5py.File(f'{feature_dir}/lafs.h5', mode='r') as f_laf, \
         h5py.File(f'{feature_dir}/descriptors.h5', mode='r') as f_desc, \
         h5py.File(f'{feature_dir}/matches.h5', mode='w') as f_match:
        for idx1, idx2 in progress_bar(index_pairs):
            fname1, fname2 = img_fnames[idx1], img_fnames[idx2]
            key1, key2 = fname1.split('/')[-1], fname2.split('/')[-1]
            lafs1 = _load(f_laf, key1)
            lafs2 = _load(f_laf, key2)
            desc1 = _load(f_desc, key1)
            desc2 = _load(f_desc, key2)

            if matching_alg == 'adalam':
                # AdaLAM also uses the image sizes, so read both images for their shapes.
                shape1, shape2 = cv2.imread(fname1).shape, cv2.imread(fname2).shape
                adalam_config = KF.adalam.get_adalam_default_config()
                adalam_config['force_seed_mnn'] = False
                adalam_config['search_expansion'] = 16
                adalam_config['ransac_iters'] = 128
                adalam_config['device'] = device
                _, idxs = KF.match_adalam(desc1, desc2,
                                          lafs1, lafs2,
                                          hw1=shape1[:2], hw2=shape2[:2],
                                          config=adalam_config)
            else:
                _, idxs = KF.match_smnn(desc1, desc2, 0.9)

            if len(idxs) == 0:
                continue
            if force_mutual:
                # Keep a single match (as chosen by get_unique_idxs) per image-2 keypoint.
                idxs = idxs[get_unique_idxs(idxs[:, 1])]

            group = f_match.require_group(key1)
            if len(idxs) >= min_matches:
                group.create_dataset(key2, data=idxs.detach().cpu().numpy().reshape(-1, 2))

def calc_roi_coords(w, h, mkpts, crop_roi_min_side=100, margin=50):
    """Compute a crop box (region of interest) around matched keypoints.

    Args:
        w, h: image width and height in pixels.
        mkpts: array of shape (N, 2) whose rows are (x, y), with x along the width
            and y along the height. N must be at least 1.
        crop_roi_min_side: minimum side length of the box before the margin is
            added. The box is widened symmetrically and shifted to stay inside
            the image; if the image is smaller, the full extent is used.
        margin: extra pixels added on every side, clipped to the image.

    Returns:
        (left, right, top, bottom) as ints, describing [left, right) x [top, bottom).
    """
    def wiggle(a, b, minl, bound):
        """Widen [a, b) symmetrically to length ``minl``, shifted to lie inside [0, bound)."""
        if b < a:
            a, b = b, a
        if minl >= bound:
            return 0, bound
        if (b - a) >= minl:
            return a, b
        d = minl - (b - a)
        pad_l, pad_r = d // 2, d - d // 2
        na, nb = a - pad_l, b + pad_r
        # Shift the padded interval back inside the image. This cannot overflow
        # on both sides because minl < bound here.
        if na < 0:
            na, nb = 0, minl
        if nb > bound:
            na, nb = bound - minl, bound
        return na, nb

    left, right = int(np.floor(mkpts[:, 0].min())), int(np.ceil(mkpts[:, 0].max()))
    top, bottom = int(np.floor(mkpts[:, 1].min())), int(np.ceil(mkpts[:, 1].max()))
    left, right = wiggle(left, right, crop_roi_min_side, w)
    top, bottom = wiggle(top, bottom, crop_roi_min_side, h)
    left, right = max(0, left - margin), min(right + margin, w)
    top, bottom = max(0, top - margin), min(bottom + margin, h)
    return left, right, top, bottom

def match_loftr_superglue(img_fnames,
                   index_pairs,
                   feature_dir = '.featureout_loftr',
                   device=torch.device('cuda'),
                   min_matches=15, resize_to_ = (640, 480)):
    """Ensemble LoFTR + SuperGlue matching for image pairs, written to HDF5.

    For every ``(idx1, idx2)`` in ``index_pairs`` the module-level matchers are
    run and their correspondences concatenated:
    ``loftr_matcher`` in both directions at long side 1024, and
    ``superglue_matcher`` at long sides 1024 and 1440 with 'orig' and 'flip_lr'
    TTA groups. The raw correspondences are then converted into per-image
    keypoint lists plus index matches.

    Files written into ``feature_dir`` (not created here, so it must already exist):
        matches_loftr.h5: raw correspondences, group key1 / dataset key2,
            rows (x, y) in key1 followed by (x, y) in key2.
        keypoints.h5: per image basename, unique keypoints rounded to integer
            pixels. Only images that occur in at least one stored match appear.
        matches.h5: group key1 / dataset key2, rows of
            (keypoint index in key1, keypoint index in key2).

    Args:
        img_fnames: image paths; HDF5 keys are the path basenames.
        index_pairs: iterable of (idx1, idx2) indices into ``img_fnames``.
        feature_dir: output directory.
        device: NOTE(review): not used by this function.
        min_matches: pairs with fewer raw correspondences are not stored.
        resize_to_: NOTE(review): not used by this function.
    """

    # Pass 1: run every matcher on every pair and store the raw correspondences.
    with h5py.File(f'{feature_dir}/matches_loftr.h5', mode='w') as f_match:
        for pair_idx in progress_bar(index_pairs):
            idx1, idx2 = pair_idx
            fname1, fname2 = img_fnames[idx1], img_fnames[idx2]
            key1, key2 = fname1.split('/')[-1], fname2.split('/')[-1]

            img1 = cv2.imread(fname1)
            img2 = cv2.imread(fname2)
            
            # LoFTR both ways; the reversed call's outputs are swapped back into (img1, img2) order.
            mkpts0_loftr, mkpts1_loftr = loftr_matcher(img1, img2, 1024)
            mkpts1_loftr_lr, mkpts0_loftr_lr = loftr_matcher(img2, img1, 1024)
            mkpts0_superglue_1024, mkpts1_superglue_1024 = superglue_matcher(img1, img2, 1024, tta_groups=[('orig', 'orig'),('flip_lr', 'flip_lr')])
            mkpts0_superglue_1440, mkpts1_superglue_1440 = superglue_matcher(img1, img2, 1440, tta_groups=[('orig', 'orig'),('flip_lr', 'flip_lr')])
            
            mkpts0 = np.concatenate([mkpts0_loftr,mkpts0_superglue_1024,mkpts0_superglue_1440,mkpts0_loftr_lr], axis=0)
            mkpts1 = np.concatenate([mkpts1_loftr,mkpts1_superglue_1024,mkpts1_superglue_1440,mkpts1_loftr_lr], axis=0)
            
            n_matches = len(mkpts1)
            # NOTE: the key1 group is created even when this pair is skipped below.
            group  = f_match.require_group(key1)
            if n_matches >= min_matches:
                 group.create_dataset(key2, data=np.concatenate([mkpts0, mkpts1], axis=1))
#     clean_index_pairs = [(i,j) for pi, (i,j) in enumerate(index_pairs) if pi not in drop_pair]
    # Pass 2: give every stored endpoint (from all matchers, LoFTR and SuperGlue)
    # a running per-image index. kpts[k] collects image k's endpoints in file
    # iteration order. match_indexes[k1][k2] holds (running index in k1,
    # running index in k2) for each raw match of that pair.
    kpts = defaultdict(list)
    match_indexes = defaultdict(dict)
    total_kpts=defaultdict(int)
    with h5py.File(f'{feature_dir}/matches_loftr.h5', mode='r') as f_match:
        for k1 in f_match.keys():
            group  = f_match[k1]
            for k2 in group.keys():
                matches = group[k2][...]
                total_kpts[k1]  # only touches the defaultdict entry; no other effect
                kpts[k1].append(matches[:, :2])
                kpts[k2].append(matches[:, 2:])
                current_match = torch.arange(len(matches)).reshape(-1, 1).repeat(1, 2)
                # Offset by the number of endpoints already collected for each image.
                current_match[:, 0]+=total_kpts[k1]
                current_match[:, 1]+=total_kpts[k2]
                total_kpts[k1]+=len(matches)
                total_kpts[k2]+=len(matches)
                match_indexes[k1][k2]=current_match

    # Pass 3: round endpoints to integer pixels and merge duplicates per image.
    for k in kpts.keys():
        kpts[k] = np.round(np.concatenate(kpts[k], axis=0))
    unique_kpts = {}
    unique_match_idxs = {}  # per image: running index -> row in unique_kpts[image]
    out_match = defaultdict(dict)
    for k in kpts.keys():
        uniq_kps, uniq_reverse_idxs = torch.unique(torch.from_numpy(kpts[k]),dim=0, return_inverse=True)
        unique_match_idxs[k] = uniq_reverse_idxs
        unique_kpts[k] = uniq_kps.numpy()
    # Pass 4: remap each pair's matches onto unique keypoints, then deduplicate.
    # Identical (kp1, kp2) rows are merged first. Then a single match is kept per
    # key1 keypoint, and finally per key2 keypoint (selection by get_unique_idxs).
    for k1, group in match_indexes.items():
        for k2, m in group.items():
            m2 = deepcopy(m)
            m2[:,0] = unique_match_idxs[k1][m2[:,0]]
            m2[:,1] = unique_match_idxs[k2][m2[:,1]]
            mkpts = np.concatenate([unique_kpts[k1][ m2[:,0]],
                                    unique_kpts[k2][  m2[:,1]],
                                   ],
                                   axis=1)
            unique_idxs_current = get_unique_idxs(torch.from_numpy(mkpts), dim=0)
            m2_semiclean = m2[unique_idxs_current]
            unique_idxs_current1 = get_unique_idxs(m2_semiclean[:, 0], dim=0)
            m2_semiclean = m2_semiclean[unique_idxs_current1]
            unique_idxs_current2 = get_unique_idxs(m2_semiclean[:, 1], dim=0)
            m2_semiclean2 = m2_semiclean[unique_idxs_current2]
            out_match[k1][k2] = m2_semiclean2.numpy()
    # Pass 5: write keypoints.h5 and matches.h5 in the layout described above.
    with h5py.File(f'{feature_dir}/keypoints.h5', mode='w') as f_kp:
        for k, kpts1 in unique_kpts.items():
            f_kp[k] = kpts1
    
    with h5py.File(f'{feature_dir}/matches.h5', mode='w') as f_match:
        for k1, gr in out_match.items():
            group  = f_match.require_group(k1)
            for k2, match in gr.items():
                group[k2] = match
    return

def match_superglue(img_fnames,
                   index_pairs,
                   feature_dir = '.featureout_loftr',
                   device=torch.device('cuda'),
                   min_matches=15, resize_to_ = (640, 480)):
    """SuperGlue-only (8192-keypoint SuperPoint) matching for image pairs, written to HDF5.

    For every ``(idx1, idx2)`` in ``index_pairs`` the module-level
    ``superglue_matcher_8096`` is run at long sides 1024 and 1440, using that
    matcher's default TTA groups, and the correspondences are concatenated.
    Post-processing is identical to ``match_loftr_superglue``. The raw file is
    still named ``matches_loftr.h5`` even though LoFTR is not used here.

    Files written into ``feature_dir`` (not created here, so it must already exist):
        matches_loftr.h5: raw correspondences, group key1 / dataset key2,
            rows (x, y) in key1 followed by (x, y) in key2.
        keypoints.h5: per image basename, unique keypoints rounded to integer
            pixels. Only images that occur in at least one stored match appear.
        matches.h5: group key1 / dataset key2, rows of
            (keypoint index in key1, keypoint index in key2).

    Args:
        img_fnames: image paths; HDF5 keys are the path basenames.
        index_pairs: iterable of (idx1, idx2) indices into ``img_fnames``.
        feature_dir: output directory.
        device: NOTE(review): not used by this function.
        min_matches: pairs with fewer raw correspondences are not stored.
        resize_to_: NOTE(review): not used by this function.
    """

    # Pass 1: run the matcher at two resolutions and store the raw correspondences.
    with h5py.File(f'{feature_dir}/matches_loftr.h5', mode='w') as f_match:
        for pair_idx in progress_bar(index_pairs):
            idx1, idx2 = pair_idx
            fname1, fname2 = img_fnames[idx1], img_fnames[idx2]
            key1, key2 = fname1.split('/')[-1], fname2.split('/')[-1]

            img1 = cv2.imread(fname1)
            img2 = cv2.imread(fname2)
            
            mkpts0_superglue_1024, mkpts1_superglue_1024 = superglue_matcher_8096(img1, img2, 1024)
            mkpts0_superglue_1440, mkpts1_superglue_1440 = superglue_matcher_8096(img1, img2, 1440)
            
            mkpts0 = np.concatenate([mkpts0_superglue_1024,mkpts0_superglue_1440], axis=0)
            mkpts1 = np.concatenate([mkpts1_superglue_1024,mkpts1_superglue_1440], axis=0)
            
            n_matches = len(mkpts1)
            # NOTE: the key1 group is created even when this pair is skipped below.
            group  = f_match.require_group(key1)
            if n_matches >= min_matches:
                 group.create_dataset(key2, data=np.concatenate([mkpts0, mkpts1], axis=1))
#     clean_index_pairs = [(i,j) for pi, (i,j) in enumerate(index_pairs) if pi not in drop_pair]
    # Pass 2: give every stored endpoint a running per-image index.
    # kpts[k] collects image k's endpoints in file iteration order.
    # match_indexes[k1][k2] holds (running index in k1, running index in k2)
    # for each raw match of that pair.
    kpts = defaultdict(list)
    match_indexes = defaultdict(dict)
    total_kpts=defaultdict(int)
    with h5py.File(f'{feature_dir}/matches_loftr.h5', mode='r') as f_match:
        for k1 in f_match.keys():
            group  = f_match[k1]
            for k2 in group.keys():
                matches = group[k2][...]
                total_kpts[k1]  # only touches the defaultdict entry; no other effect
                kpts[k1].append(matches[:, :2])
                kpts[k2].append(matches[:, 2:])
                current_match = torch.arange(len(matches)).reshape(-1, 1).repeat(1, 2)
                # Offset by the number of endpoints already collected for each image.
                current_match[:, 0]+=total_kpts[k1]
                current_match[:, 1]+=total_kpts[k2]
                total_kpts[k1]+=len(matches)
                total_kpts[k2]+=len(matches)
                match_indexes[k1][k2]=current_match

    # Pass 3: round endpoints to integer pixels and merge duplicates per image.
    for k in kpts.keys():
        kpts[k] = np.round(np.concatenate(kpts[k], axis=0))
    unique_kpts = {}
    unique_match_idxs = {}  # per image: running index -> row in unique_kpts[image]
    out_match = defaultdict(dict)
    for k in kpts.keys():
        uniq_kps, uniq_reverse_idxs = torch.unique(torch.from_numpy(kpts[k]),dim=0, return_inverse=True)
        unique_match_idxs[k] = uniq_reverse_idxs
        unique_kpts[k] = uniq_kps.numpy()
    # Pass 4: remap each pair's matches onto unique keypoints, then deduplicate.
    # Identical (kp1, kp2) rows are merged first. Then a single match is kept per
    # key1 keypoint, and finally per key2 keypoint (selection by get_unique_idxs).
    for k1, group in match_indexes.items():
        for k2, m in group.items():
            m2 = deepcopy(m)
            m2[:,0] = unique_match_idxs[k1][m2[:,0]]
            m2[:,1] = unique_match_idxs[k2][m2[:,1]]
            mkpts = np.concatenate([unique_kpts[k1][ m2[:,0]],
                                    unique_kpts[k2][  m2[:,1]],
                                   ],
                                   axis=1)
            unique_idxs_current = get_unique_idxs(torch.from_numpy(mkpts), dim=0)
            m2_semiclean = m2[unique_idxs_current]
            unique_idxs_current1 = get_unique_idxs(m2_semiclean[:, 0], dim=0)
            m2_semiclean = m2_semiclean[unique_idxs_current1]
            unique_idxs_current2 = get_unique_idxs(m2_semiclean[:, 1], dim=0)
            m2_semiclean2 = m2_semiclean[unique_idxs_current2]
            out_match[k1][k2] = m2_semiclean2.numpy()
    # Pass 5: write keypoints.h5 and matches.h5 in the layout described above.
    with h5py.File(f'{feature_dir}/keypoints.h5', mode='w') as f_kp:
        for k, kpts1 in unique_kpts.items():
            f_kp[k] = kpts1
    
    with h5py.File(f'{feature_dir}/matches.h5', mode='w') as f_match:
        for k1, gr in out_match.items():
            group  = f_match.require_group(k1)
            for k2, match in gr.items():
                group[k2] = match
    return

def import_into_colmap(img_dir,
                       feature_dir ='.featureout',
                       database_path = 'colmap.db',
                       img_ext='.jpg'):
    """Create a COLMAP database at ``database_path`` and load features from ``feature_dir``.

    Keypoints are added via ``add_keypoints`` with the 'simple-radial' camera
    model and ``single_camera=False``. Matches are then added via
    ``add_matches``, and the database is committed.
    """
    db = COLMAPDatabase.connect(database_path)
    db.create_tables()
    use_single_camera = False
    image_ids = add_keypoints(db, feature_dir, img_dir, img_ext, 'simple-radial', use_single_camera)
    add_matches(db, feature_dir, image_ids)
    db.commit()

In [None]:
src = '/kaggle/input/image-matching-challenge-2023'
# DEBUG reads the training labels; otherwise the sample submission lists the test images.
if DEBUG:
    csv_file = 'train/train_labels.csv'
else:
    csv_file = 'sample_submission.csv'

# data_dict[dataset][scene] -> list of image paths, in file order.
data_dict = {}
with open(f'{src}/{csv_file}', 'r') as f:
    next(f, None)  # skip the header row
    for raw_line in f:
        line = raw_line.strip()
        if not line:
            # A blank line would otherwise crash the 5-field unpack below.
            continue
        fields = line.split(',')
        # The two files order their columns differently.
        if DEBUG:
            dataset, scene, image, _, _ = fields
        else:
            image, dataset, scene, _, _ = fields
        data_dict.setdefault(dataset, {}).setdefault(scene, []).append(image)

In [None]:
# Report how many images each dataset/scene contains.
for dataset, scenes in data_dict.items():
    for scene, scene_images in scenes.items():
        print(f'{dataset} / {scene} -> {len(scene_images)} images')

2cfa01ab573141e4 / 2fa124afd1f74f38 -> 3 images


In [None]:
# out_results[dataset][scene] collects the per-image results of the main loop.
out_results = {}
# Wall-clock durations in seconds, one list per pipeline stage.
timings = {
    stage: []
    for stage in ("shortlisting", "feature_detection", "feature_matching", "RANSAC", "Reconstruction")
}

In [None]:
# Function to create a submission file.
def create_submission(out_results, data_dict, output_path='submission.csv'):
    """Write the competition submission CSV.

    One row is written for every image in ``data_dict[dataset][scene]``, in that
    order. Images with an entry in ``out_results[dataset][scene][image]`` (a dict
    holding 'R' and 't' arrays) get that pose flattened by ``arr_to_str``.
    Every other image gets the identity rotation and a zero translation.

    Args:
        out_results: dataset -> scene -> image -> {'R': array, 't': array}.
        data_dict: dataset -> scene -> list of image paths.
        output_path: destination CSV path (defaults to ``submission.csv``).
    """
    with open(output_path, 'w') as f:
        f.write('image_path,dataset,scene,rotation_matrix,translation_vector\n')
        for dataset, scenes in data_dict.items():
            dataset_res = out_results.get(dataset, {})
            for scene, images in scenes.items():
                # An empty fallback avoids spurious hits for images literally named 'R' or 't'.
                scene_res = dataset_res.get(scene, {})
                for image in images:
                    if image in scene_res:
                        print(image)
                        R = scene_res[image]['R'].reshape(-1)
                        T = scene_res[image]['t'].reshape(-1)
                    else:
                        R = np.eye(3).reshape(-1)
                        T = np.zeros(3)
                    f.write(f'{image},{dataset},{scene},{arr_to_str(R)},{arr_to_str(T)}\n')

In [None]:
import pandas as pd

In [None]:
import traceback


def _make_mapper_options():
    """Return pycolmap incremental-mapping options with this notebook's settings.

    Newer pycolmap releases carry the pipeline-level fields used here
    (``min_model_size``, ``ba_local_max_refinements``) on
    ``IncrementalPipelineOptions``; older releases carry them on
    ``IncrementalMapperOptions``. Use the former when it exists.
    """
    options_cls = getattr(pycolmap, 'IncrementalPipelineOptions', None)
    if options_cls is None:
        options_cls = pycolmap.IncrementalMapperOptions
    options = options_cls()
    # By default colmap does not generate a reconstruction if less than 10
    # images are registered. Lower it to 3.
    options.min_model_size = 3
    options.ba_local_max_refinements = 2
    return options


def _image_pose(im):
    """Return independent copies of (R, t) for a registered pycolmap image.

    Uses ``im.cam_from_world`` when available (a property on some pycolmap
    releases, a method on others). Otherwise falls back to the older
    ``im.rotmat()`` / ``im.tvec`` API.
    """
    if not hasattr(im, 'cam_from_world'):
        return deepcopy(im.rotmat()), deepcopy(im.tvec)
    cam_from_world = im.cam_from_world
    if callable(cam_from_world):
        cam_from_world = cam_from_world()
    return (np.array(cam_from_world.rotation.matrix(), dtype=float),
            np.array(cam_from_world.translation, dtype=float))


dataset_type = 'train' if DEBUG else 'test'

gc.collect()
datasets = list(data_dict)

for dataset in datasets:
    print(dataset)
    out_results.setdefault(dataset, {})
    for scene in data_dict[dataset]:
        print(scene)
        # Fail gently if the notebook has not been submitted and the test data is not populated.
        # You may want to run this on the training data in that case?
        img_dir = f'{src}/{dataset_type}/{dataset}/{scene}/images'
        if not os.path.exists(img_dir):
            continue
        # Best effort per scene: a failure is logged and the scene keeps
        # whatever poses were stored, instead of aborting the whole run.
        try:
            out_results[dataset][scene] = {}
            scene_out = out_results[dataset][scene]
            img_fnames = [f'{src}/{dataset_type}/{x}' for x in data_dict[dataset][scene]]
            print(f"Got {len(img_fnames)} images")
            feature_dir = f'/kaggle/temp/featureout/{dataset}_{scene}'
            os.makedirs(feature_dir, exist_ok=True)

            t = time()
            index_pairs = get_image_pairs_shortlist(img_fnames,
                                                    sim_th=0.5,  # should be strict
                                                    min_pairs=35,  # we select at least min_pairs PER IMAGE with biggest similarity
                                                    exhaustive_if_less=20,
                                                    device=device)
            t = time() - t
            timings['shortlisting'].append(t)
            print(f'{len(index_pairs)}, pairs to match, {t:.4f} sec')
            gc.collect()

            t = time()
            # 400+ shortlisted pairs go through match_superglue; smaller
            # shortlists through match_loftr_superglue (both defined elsewhere).
            if len(index_pairs) >= 400:
                match_superglue(img_fnames, index_pairs, feature_dir=feature_dir, device=device, resize_to_=(600, 800))
            else:
                match_loftr_superglue(img_fnames, index_pairs, feature_dir=feature_dir, device=device, resize_to_=(600, 800))
            t = time() - t
            timings['feature_matching'].append(t)
            print(f'Features matched in  {t:.4f} sec')

            database_path = f'{feature_dir}/colmap.db'
            if os.path.isfile(database_path):
                os.remove(database_path)
            gc.collect()
            import_into_colmap(img_dir, feature_dir=feature_dir, database_path=database_path)
            output_path = f'{feature_dir}/colmap_rec_{LOCAL_FEATURE}'

            t = time()
            pycolmap.match_exhaustive(database_path)
            t = time() - t
            timings['RANSAC'].append(t)
            print(f'RANSAC in  {t:.4f} sec')

            t = time()
            mapper_options = _make_mapper_options()
            os.makedirs(output_path, exist_ok=True)
            maps = pycolmap.incremental_mapping(database_path=database_path, image_path=img_dir,
                                                output_path=output_path, options=mapper_options)
            print(maps)
            t = time() - t
            timings['Reconstruction'].append(t)
            print(f'Reconstruction done in  {t:.4f} sec')

            # The "best" reconstruction is the one with the most images.
            print("Looking for the best reconstruction")
            recs = maps if isinstance(maps, dict) else {}
            imgs_registered = 0
            best_idx = None
            for idx1, rec in recs.items():
                print(idx1, rec.summary())
                if len(rec.images) > imgs_registered:
                    imgs_registered = len(rec.images)
                    best_idx = idx1
            if best_idx is not None:
                print(recs[best_idx].summary())

            # Take every pose from the best reconstruction first, then fill in
            # images that only the other reconstructions contain.
            ordered_idx = ([best_idx] if best_idx is not None else []) + [i for i in recs if i != best_idx]
            for idx1 in ordered_idx:
                for _, im in recs[idx1].images.items():
                    key1 = f'{dataset}/{scene}/images/{im.name}'
                    if key1 in scene_out:
                        continue
                    R, tvec = _image_pose(im)
                    # Store the entry only once both arrays exist, so a failure
                    # never leaves a half-filled pose behind.
                    scene_out[key1] = {"R": R, "t": tvec}

            print(f'Registered: {dataset} / {scene} -> {len(scene_out)} images')
            print(f'Total: {dataset} / {scene} -> {len(data_dict[dataset][scene])} images')
            create_submission(out_results, data_dict)
        except Exception:
            print(f'Failed on {dataset} / {scene}:')
            traceback.print_exc()
        gc.collect()

2cfa01ab573141e4
2fa124afd1f74f38


In [None]:
# Always (re)write the CSV at the end so every listed image gets a row.
create_submission(out_results=out_results, data_dict=data_dict)