
<center>
    <h2 style="color: #022047"> An image to get in the right mood... 😉  </h2>
</center>

![](https://storage.googleapis.com/kaggle-media/competitions/google-image-matching/trevi-canvas-licensed-nonoderivs.jpg)

In [None]:
import os
import numpy as np
import cv2
import csv
from glob import glob
import torch
import matplotlib.pyplot as plt
import gc


import torch
# Per-session registry of optional components; each install cell below adds
# its key here so re-running that cell can skip the reinstall.
INSTALLED_LOG = {}

# Prefer the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
if device.type != 'cuda':
    print('You may want to enable the GPU switch?')

# OpenGlue

In [None]:
%%writefile /tmp/openglue_comment_out_cell.txt

# NOTE: because of the `%%writefile` magic above, IPython saves this entire
# cell's text to the file instead of executing it. The installs, imports, the
# OpenGlueMatcherWrapper class and `openglue_matcher` below are therefore NOT
# run. Delete the magic line to re-enable OpenGlue.

# Install OmegaConf offline: first the antlr4 runtime archive (stored under a
# renamed extension in the dataset, presumably so Kaggle does not auto-extract
# it), then the OmegaConf wheel.
!cp ../input/detectron2-v06/antlr4-python3-runtime-4.8.targz_compressed /tmp/antlr4-python3-runtime-4.8.tar.gz
!pip install /tmp/antlr4-python3-runtime-4.8.tar.gz
!pip install ../input/detectron2-v06/omegaconf-2.1.2-py3-none-any.whl

# Copy the OpenGlue sources to /tmp/openglue/openglue and add an __init__.py
# so the directory can be imported as a package.
# NOTE(review): plain `mkdir` prints an error if the directory already exists.
!mkdir /tmp/openglue/
!cp -r ../input/openglue/OpenGlue-main /tmp/openglue/openglue
!touch /tmp/openglue/openglue/__init__.py

# Copy the pretrained OpenGlue checkpoints into ./pretrained/checkpoints
# (presumably for torch.hub, whose dir is set to /kaggle/working/pretrained/
# below; assumes the working directory is /kaggle/working).
!mkdir -p pretrained/checkpoints
!cp ../input/openglue-models/* pretrained/checkpoints/
!ls pretrained/checkpoints/


import sys
sys.path.append('/tmp/openglue/openglue')
import cv2
import os
from omegaconf import OmegaConf
from typing import Dict, Optional
import argparse
import matplotlib.pyplot as plt

import kornia as K
import kornia.feature as KF
import kornia_moons.feature as KMF

import torch
import torch.nn as nn

import openglue
from openglue.inference import initialize_models, OpenGlueMatcher
from openglue.models.features import get_feature_extractor
from openglue.models.superglue.superglue import SuperGlue as OpenGlueSuperGlue


class OpenGlueMatcherWrapper:
    """Callable matcher: OpenGlue SuperGlue on DoG + AffNet + HardNet local features.

    Calling the instance with two BGR uint8 numpy images returns the matched
    keypoint coordinates ``(mkpts0, mkpts1)`` as numpy arrays.
    """

    # Prefix removed from checkpoint weight names so they match the bare
    # SuperGlue module's own parameter names.
    _CKPT_PREFIX = 'superglue.'

    def __init__(self, device=None):
        torch.hub.set_dir('/kaggle/working/pretrained/')
        local_feature_extractor = get_feature_extractor('OPENCVDoGAffNetHardNet')(max_keypoints=2048, nms_diameter=9)
        local_feature_extractor.to(device)
        self.device = device
        
        yaml_cfg_string_new = """
                              inference:
                                match_threshold: 0.2

                              superglue:
                                descriptor_dim: &DESCRIPTOR_DIM 128
                                laf_to_sideinfo_method: affine
                                positional_encoding:
                                  hidden_layers_sizes: [ 32, 64, 128 ]
                                  side_info_size: 6
                                  output_size: *DESCRIPTOR_DIM
                                attention_gnn:
                                  num_stages: 9
                                  num_heads: 4
                                  embed_dim: *DESCRIPTOR_DIM
                                  attention: 'softmax'
                                  use_offset: False
                                dustbin_score_init: 1.0
                                otp:
                                  num_iters: 20
                                  reg: 1.0
                                residual: True
                              """
        self.config = OmegaConf.create(yaml_cfg_string_new)

        checkpoint = torch.load('../input/openglue-models/openglue_SIFT-Affnet-Hardnet.ckpt', map_location='cpu')
        # Only a *leading* prefix is removed; the old str.replace also mangled
        # any key containing "superglue." in the middle.
        state_dict = self._strip_prefix(checkpoint['state_dict'], self._CKPT_PREFIX)
        superglue_matcher = OpenGlueSuperGlue(self.config['superglue'])
        message = superglue_matcher.load_state_dict(state_dict)
        print(message)
        superglue_matcher.to(device)

        self._openglue_matcher = OpenGlueMatcher(local_feature_extractor, superglue_matcher, self.config)

    @staticmethod
    def _strip_prefix(state_dict, prefix):
        """Return a new dict whose keys have a leading ``prefix`` removed (input untouched)."""
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def load_torch_image(self, image, resize_to=None):
        """Convert a BGR uint8 numpy image to a grayscale float tensor on ``self.device``.

        Args:
            image: numpy image; scaled by 1/255 after ``K.image_to_tensor``.
            resize_to: optional ``(width, height)`` to resize the tensor to.
        """
        timg = K.color.bgr_to_grayscale(K.image_to_tensor(image, False) / 255.).to(self.device)
        if resize_to is not None:
            new_w, new_h = resize_to
            timg = K.geometry.resize(timg, (new_h, new_w))
        return timg

    def __call__(self, img_np1, img_np2):
        """Match two images and return ``(keypoints0, keypoints1)`` as numpy arrays."""
        with torch.no_grad():
            img_ts1 = self.load_torch_image(img_np1)
            img_ts2 = self.load_torch_image(img_np2)
            out = self._openglue_matcher({"image0": img_ts1, "image1": img_ts2})
            mkpts0 = out['keypoints0'].cpu().numpy()
            mkpts1 = out['keypoints1'].cpu().numpy()
            return mkpts0, mkpts1
        
        
# Build the OpenGlue matcher once (its constructor loads the checkpoint).
# NOTE(review): this line belongs to the cell that begins with %%writefile,
# so it is written to /tmp/openglue_comment_out_cell.txt rather than executed.
openglue_matcher = OpenGlueMatcherWrapper(device=device)

## Kornia LoFTR

In [None]:
# Install Kornia and kornia_moons offline from the wheels in the kornia-loftr dataset.
# Set this to True to reinstall even if INSTALLED_LOG says it was already done.
force_kornialoftr_reinstall = False

# INSTALLED_LOG records what this session has installed, so re-running the
# cell skips the pip installs unless the force flag is set.
if 'KorniaLoFTR' not in INSTALLED_LOG or force_kornialoftr_reinstall:
    dry_run = False  # NOTE(review): not read anywhere in this cell
    !pip install ../input/kornia-loftr/kornia-0.6.4-py2.py3-none-any.whl
    !pip install ../input/kornia-loftr/kornia_moons-0.1.9-py3-none-any.whl
    INSTALLED_LOG['KorniaLoFTR'] = True
else:
    print('Already installed KorniaLoFTR. Set "force_kornialoftr_reinstall=True" to override this behavior.')
    

# Import and use Kornia
import kornia
import kornia as K
import kornia.feature as KF
import kornia_moons.feature as KMF


class LoFTRMatcher:
    """Kornia LoFTR matcher with optional test-time augmentation (TTA).

    Both images are resized so their longer side is ``input_longside``. Every
    requested TTA view is matched in a single batched LoFTR call. Keypoints
    are mapped back to the resized frame, filtered by confidence and finally
    divided by the resize scale to return to input-image coordinates.
    """

    # Rotation TTA name -> angle in degrees passed to cv2.getRotationMatrix2D.
    _ROTATION_TTA = {'rot_r10': 10, 'rot_l10': -10}
    # Every TTA name accepted by __call__.
    _SUPPORTED_TTA = ('orig', 'flip_lr', 'flip_ud') + tuple(_ROTATION_TTA)
    # Added to the confidence of rotated-view matches whose un-rotated
    # keypoints fall outside the image, pushing them below any non-negative
    # threshold (LoFTR confidences are presumably in [0, 1]).
    _OUT_OF_FRAME_PENALTY = -10.

    def __init__(self, device=None, input_longside=1200, conf_th=None):
        self._loftr_matcher = KF.LoFTR(pretrained=None)
        self._loftr_matcher.load_state_dict(torch.load("../input/kornia-loftr/loftr_outdoor.ckpt")['state_dict'])
        self._loftr_matcher = self._loftr_matcher.to(device).eval()
        self.device = device
        self.input_longside = input_longside
        self.conf_thresh = conf_th

    def prep_img(self, img, long_side=1200):
        """Resize a BGR image so its longer side is ``long_side`` (if given).

        Returns:
            (resized numpy image, grayscale float tensor on ``self.device``, scale factor)
        """
        if long_side is not None:
            scale = long_side / max(img.shape[0], img.shape[1])
            w = int(img.shape[1] * scale)
            h = int(img.shape[0] * scale)
            img = cv2.resize(img, (w, h))
        else:
            scale = 1.0

        img_ts = K.image_to_tensor(img, False).float() / 255.
        img_ts = K.color.bgr_to_rgb(img_ts)
        img_ts = K.color.rgb_to_grayscale(img_ts)
        return img, img_ts.to(self.device), scale

    def tta_rotation_preprocess(self, img_np, angle):
        """Rotate a BGR image by ``angle`` degrees about its centre (same canvas size).

        Returns:
            (forward 2x3 matrix, rotated grayscale tensor on ``self.device``, inverse 2x3 matrix)
        """
        rot_M = cv2.getRotationMatrix2D((img_np.shape[1] / 2, img_np.shape[0] / 2), angle, 1)
        rot_M_inv = cv2.getRotationMatrix2D((img_np.shape[1] / 2, img_np.shape[0] / 2), -angle, 1)
        rot_img = cv2.warpAffine(img_np, rot_M, (img_np.shape[1], img_np.shape[0]))

        rot_img_ts = K.image_to_tensor(rot_img, False).float() / 255.
        rot_img_ts = K.color.bgr_to_rgb(rot_img_ts)
        rot_img_ts = K.color.rgb_to_grayscale(rot_img_ts)
        return rot_M, rot_img_ts.to(self.device), rot_M_inv

    def tta_rotation_postprocess(self, kpts, img_np, rot_M_inv):
        """Apply the 2x3 affine ``rot_M_inv`` to (N, 2) keypoints.

        Returns:
            (transformed keypoints, boolean mask of those inside ``img_np``'s width/height)
        """
        ones = np.ones(shape=(kpts.shape[0], ), dtype=np.float32)[:, None]
        hom = np.concatenate([kpts, ones], 1)
        rot_kpts = rot_M_inv.dot(hom.T).T[:, :2]
        mask = (rot_kpts[:, 0] >= 0) & (rot_kpts[:, 0] < img_np.shape[1]) & (rot_kpts[:, 1] >= 0) & (rot_kpts[:, 1] < img_np.shape[0])
        return rot_kpts, mask

    @staticmethod
    def _unflip(kpts, batch_mask, size, axis):
        """Undo a flip along ``axis`` (0 = x, 1 = y) in place for rows selected by ``batch_mask``."""
        kpts[batch_mask, axis] = size - kpts[batch_mask, axis]

    def __call__(self, img_np1, img_np2, tta=('orig',)):
        """Match two BGR images; returns ``(mkpts0, mkpts1)`` in input-image pixels.

        Args:
            img_np1, img_np2: BGR uint8 numpy images.
            tta: iterable of TTA names from ``_SUPPORTED_TTA``. Each adds one
                batch entry to the LoFTR call.

        Raises:
            ValueError: if ``tta`` contains an unknown name.
        """
        tta = list(tta)
        if any(t not in self._SUPPORTED_TTA for t in tta):
            raise ValueError('Unknown TTA method.')

        with torch.no_grad():
            img_np1, img_ts0, scale0 = self.prep_img(img_np1, long_side=self.input_longside)
            img_np2, img_ts1, scale1 = self.prep_img(img_np2, long_side=self.input_longside)

            # Build one augmented view pair per TTA name; rotation views also
            # remember their inverse matrices for the reverse pass.
            images0, images1 = [], []
            rot_inv = {}
            for tta_elem in tta:
                if tta_elem == 'orig':
                    aug0, aug1 = img_ts0, img_ts1
                elif tta_elem == 'flip_lr':
                    aug0, aug1 = torch.flip(img_ts0, [3]), torch.flip(img_ts1, [3])
                elif tta_elem == 'flip_ud':
                    aug0, aug1 = torch.flip(img_ts0, [2]), torch.flip(img_ts1, [2])
                else:
                    angle = self._ROTATION_TTA[tta_elem]
                    _, aug0, inv0 = self.tta_rotation_preprocess(img_np1, angle)
                    _, aug1, inv1 = self.tta_rotation_preprocess(img_np2, angle)
                    rot_inv[tta_elem] = (inv0, inv1)
                images0.append(aug0)
                images1.append(aug1)

            # Inference over the whole TTA batch at once.
            correspondences = self._loftr_matcher({"image0": torch.cat(images0), "image1": torch.cat(images1)})
            mkpts0 = correspondences['keypoints0'].cpu().numpy()
            mkpts1 = correspondences['keypoints1'].cpu().numpy()
            batch_id = correspondences['batch_indexes'].cpu().numpy()
            confidence = correspondences['confidence'].cpu().numpy()

            # Reverse each TTA on the matches that came from its batch entry.
            for idx, tta_elem in enumerate(tta):
                batch_mask = batch_id == idx
                if tta_elem == 'flip_lr':
                    self._unflip(mkpts0, batch_mask, img_np1.shape[1], 0)
                    self._unflip(mkpts1, batch_mask, img_np2.shape[1], 0)
                elif tta_elem == 'flip_ud':
                    self._unflip(mkpts0, batch_mask, img_np1.shape[0], 1)
                    self._unflip(mkpts1, batch_mask, img_np2.shape[0], 1)
                elif tta_elem in rot_inv:
                    inv0, inv1 = rot_inv[tta_elem]
                    mkpts0[batch_mask], mask0 = self.tta_rotation_postprocess(mkpts0[batch_mask], img_np1, inv0)
                    mkpts1[batch_mask], mask1 = self.tta_rotation_postprocess(mkpts1[batch_mask], img_np2, inv1)
                    confidence[batch_mask] += (~(mask0 & mask1)).astype(np.float32) * self._OUT_OF_FRAME_PENALTY

            min_conf = 0. if self.conf_thresh is None else self.conf_thresh
            keep = confidence >= min_conf
            return mkpts0[keep] / scale0, mkpts1[keep] / scale1
        

# 1200 is the validation size, according to the paper.
# Images are resized so their longer side is 1200 px before matching, and
# only matches with confidence >= 0.3 are kept.
loftr_matcher = LoFTRMatcher(device=device, input_longside=1200, conf_th=0.3)

## SuperGlue

In [None]:
# Install SuperGlue: copy the pretrained network's `models` package to
# /tmp/superpoint/superpoint and add an __init__.py, so that once
# /tmp/superpoint is on sys.path it imports as `superpoint`.
# Set this to True to redo the copy even if INSTALLED_LOG says it was done.
force_superglue_reinstall = False

if 'superglue' not in INSTALLED_LOG or force_superglue_reinstall:
    !mkdir /tmp/superpoint
    !cp -r ../input/super-glue-pretrained-network/models /tmp/superpoint/superpoint
    !ls /tmp/superpoint/superpoint
    !touch /tmp/superpoint/superpoint/__init__.py
    INSTALLED_LOG['superglue'] = True
else:
    print('Already installed SuperGlue. Set "force_superglue_reinstall=True" to override this behavior.')

# Import superglue
import sys
sys.path.append("/tmp/superpoint")
from superpoint.superpoint import SuperPoint
from superpoint.superglue import SuperGlue


class SuperGlueCustomMatchingV2(torch.nn.Module):
    """Image matching frontend (SuperPoint + SuperGlue) with test-time augmentation.

    Entry ``i`` along the batch axis of ``data['image0']`` / ``data['image1']``
    is the augmented view named ``ttas[i]``. The ``untta_*`` methods map
    keypoints detected in such a view back into the un-augmented image frame.
    """

    # Rotation (degrees, cv2.getRotationMatrix2D convention) that the
    # rotation TTAs are assumed to have applied. The names say "10", but the
    # preprocessing in this notebook rotates by 15.
    ROT_TTA_DEG = 15

    def __init__(self, config=None, device=None):
        super().__init__()
        config = {} if config is None else config
        self.superpoint = SuperPoint(config.get('superpoint', {}))
        self.superglue = SuperGlue(config.get('superglue', {}))

        # TTA name -> function undoing that augmentation on (kpts, desc, scores).
        self.tta_map = {
            'orig': self.untta_none,
            'eqhist': self.untta_none,
            'clahe': self.untta_none,
            'flip_lr': self.untta_fliplr,
            'flip_ud': self.untta_flipud,
            'rot_r10': self.untta_rotr10,
            'rot_l10': self.untta_rotl10,
            'fliplr_rotr10': self.untta_fliplr_rotr10,
            'fliplr_rotl10': self.untta_fliplr_rotl10
        }
        self.device = device

    def _extract_features(self, data):
        """Run SuperPoint on image0/image1 unless their keypoints are already in ``data``.

        Returns a dict with SuperPoint's outputs suffixed by '0' / '1'.
        """
        pred = {}
        for k in ('0', '1'):
            if f'keypoints{k}' not in data:
                feats = self.superpoint({'image': data[f'image{k}']})
                pred.update({name + k: value for name, value in feats.items()})
        return pred

    def _untta_features(self, feats, ttas, images, mask_illegal):
        """Map each view's keypoints in ``feats`` back to the un-augmented frame (in place).

        ``images`` supplies the image tensors whose dims 2 and 3 give the
        height and width used by the inverse transforms.
        """
        feats['scores0'] = list(feats['scores0'])
        feats['scores1'] = list(feats['scores1'])
        for i in range(len(feats['keypoints0'])):
            undo = self.tta_map[ttas[i]]
            for k in ('0', '1'):
                shape = images['image' + k].shape
                kp_key, desc_key, score_key = 'keypoints' + k, 'descriptors' + k, 'scores' + k
                feats[kp_key][i], feats[desc_key][i], feats[score_key][i] = undo(
                    feats[kp_key][i], feats[desc_key][i], feats[score_key][i],
                    w=shape[3], h=shape[2], inplace=True, mask_illegal=mask_illegal)

    def forward_flat(self, data, ttas=('orig',), tta_groups=(('orig',),)):
        """Un-TTA the SuperPoint features, then run SuperGlue once per TTA group.

        Keypoints of all views in a group are concatenated into one set per
        image before matching. Keypoints that leave the frame are dropped.

        Args:
            data: dict with at least 'image0' and 'image1'.
            ttas: TTA name of each batch entry.
            tta_groups: iterables of TTA names to pool together.

        Returns:
            list with one dict per group: pooled inputs plus SuperGlue outputs.
        """
        pred = self._extract_features(data)
        self._untta_features(pred, ttas, data, mask_illegal=True)
        data = {**data, **pred}

        group_preds = []
        for tta_group in tta_groups:
            members = [i for i, t in enumerate(ttas) if t in tta_group]
            # Built-in bool instead of the removed np.bool alias (NumPy >= 1.24).
            group_mask = torch.tensor([t in tta_group for t in ttas], dtype=torch.bool)
            group_data = {}
            # Descriptors are (D, N) per view, so they are concatenated on dim 1.
            for name, cat_dim in (('keypoints', 0), ('descriptors', 1), ('scores', 0)):
                for k in (0, 1):
                    group_data[f'{name}{k}'] = torch.cat([data[f'{name}{k}'][i] for i in members], cat_dim)[None, ...]
            for k in (0, 1):
                group_data[f'image{k}'] = torch.flatten(data[f'image{k}'][group_mask, ...], 0, 1)[None, ...]
            group_preds.append({**group_data, **self.superglue(group_data)})
        return group_preds

    @staticmethod
    def _cross_slices(tta_group, tta2id):
        """Return (slice for image0 view, slice for image1 view) of a cross group.

        A one-element group pairs that view with itself.
        """
        first = tta_group[0]
        second = tta_group[1] if len(tta_group) > 1 else first
        i0, i1 = tta2id[first], tta2id[second]
        return slice(i0, i0 + 1), slice(i1, i1 + 1)

    def forward_cross(self, data, ttas=('orig',), tta_groups=(('orig', 'orig'),)):
        """Run SuperGlue on (view of image0, view of image1) pairs, then un-TTA.

        Args:
            data: dict with at least 'image0' and 'image1'.
            ttas: TTA name of each batch entry.
            tta_groups: (tta for image0, tta for image1) pairs. A one-element
                group uses the same view for both images.

        Returns:
            list with SuperGlue's output dict per group, extended with that
            group's un-augmented 'keypoints{0,1}' and 'scores{0,1}'.
        """
        data = {**data, **self._extract_features(data)}
        tta2id = {t: i for i, t in enumerate(ttas)}
        group_slices = [self._cross_slices(tta_group, tta2id) for tta_group in tta_groups]

        group_pred_list = []
        for sl in group_slices:
            group_data = {}
            for name in ('image', 'keypoints', 'descriptors', 'scores'):
                for i in (0, 1):
                    part = data[f'{name}{i}'][sl[i]]
                    group_data[f'{name}{i}'] = torch.stack(part) if isinstance(part, (list, tuple)) else part
            group_pred_list.append(self.superglue(group_data))

        # Undo the TTA only after matching; no keypoints are dropped so the
        # match indices SuperGlue returned stay valid.
        self._untta_features(data, ttas, data, mask_illegal=False)

        for group_pred, sl in zip(group_pred_list, group_slices):
            group_pred.update({
                **{f'keypoints{i}': data[f'keypoints{i}'][sl[i]] for i in (0, 1)},
                **{f'scores{i}': data[f'scores{i}'][sl[i]] for i in (0, 1)},
            })
        return group_pred_list

    def untta_none(self, keypoints, descriptors, scores, w, h, inplace=True, mask_illegal=True):
        """Identity un-TTA (for photometric augmentations); clones keypoints if not in place."""
        if not inplace:
            keypoints = keypoints.clone()
        return keypoints, descriptors, scores

    def untta_fliplr(self, keypoints, descriptors, scores, w, h, inplace=True, mask_illegal=True):
        """Undo a left-right flip: x -> w - x - 1."""
        if not inplace:
            keypoints = keypoints.clone()
        keypoints[:, 0] = w - keypoints[:, 0] - 1.
        return keypoints, descriptors, scores

    def untta_flipud(self, keypoints, descriptors, scores, w, h, inplace=True, mask_illegal=True):
        """Undo an up-down flip: y -> h - y - 1."""
        if not inplace:
            keypoints = keypoints.clone()
        keypoints[:, 1] = h - keypoints[:, 1] - 1.
        return keypoints, descriptors, scores

    def _untta_rotation(self, keypoints, descriptors, scores, w, h, angle, flip_lr, mask_illegal):
        """Rotate (N, 2) keypoints by ``angle`` degrees about the image centre.

        Optionally also undoes a left-right flip, and optionally drops
        keypoints (with their descriptor columns and scores) that land
        outside the w x h frame.
        """
        rot_M_inv = torch.from_numpy(cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)).to(torch.float32).to(self.device)
        hom = torch.cat([keypoints, torch.ones_like(keypoints[:, :1])], 1)
        rot_kpts = torch.matmul(rot_M_inv, hom.T).T[:, :2]
        if flip_lr:
            rot_kpts[:, 0] = w - rot_kpts[:, 0] - 1.
        if mask_illegal:
            mask = (rot_kpts[:, 0] >= 0) & (rot_kpts[:, 0] < w) & (rot_kpts[:, 1] >= 0) & (rot_kpts[:, 1] < h)
            return rot_kpts[mask], descriptors[:, mask], scores[mask]
        return rot_kpts, descriptors, scores

    def untta_rotr10(self, keypoints, descriptors, scores, w, h, inplace=True, mask_illegal=True):
        """Undo a +ROT_TTA_DEG rotation by rotating -ROT_TTA_DEG."""
        return self._untta_rotation(keypoints, descriptors, scores, w, h, -self.ROT_TTA_DEG, False, mask_illegal)

    def untta_rotl10(self, keypoints, descriptors, scores, w, h, inplace=True, mask_illegal=True):
        """Undo a -ROT_TTA_DEG rotation by rotating +ROT_TTA_DEG."""
        return self._untta_rotation(keypoints, descriptors, scores, w, h, self.ROT_TTA_DEG, False, mask_illegal)

    def untta_fliplr_rotr10(self, keypoints, descriptors, scores, w, h, inplace=True, mask_illegal=True):
        """Undo 'flip left-right, then rotate +ROT_TTA_DEG' (un-rotate first, then un-flip)."""
        return self._untta_rotation(keypoints, descriptors, scores, w, h, -self.ROT_TTA_DEG, True, mask_illegal)

    def untta_fliplr_rotl10(self, keypoints, descriptors, scores, w, h, inplace=True, mask_illegal=True):
        """Undo 'flip left-right, then rotate -ROT_TTA_DEG' (un-rotate first, then un-flip)."""
        return self._untta_rotation(keypoints, descriptors, scores, w, h, self.ROT_TTA_DEG, True, mask_illegal)


class SuperGlueMatcherV2:
    """SuperPoint + SuperGlue matcher with grouped test-time augmentation (TTA).

    Calling the instance with two BGR images returns matched keypoints
    ``(mkpts0, mkpts1)`` in input-image pixels, pooled over all TTA groups.
    """

    # Rotation TTA name -> (angle in degrees for cv2.getRotationMatrix2D,
    # flip left-right before rotating?). Despite the "10" in the names the
    # rotation is 15 degrees.
    _ROTATION_TTA = {
        'rot_r10': (15, False),
        'rot_l10': (-15, False),
        'fliplr_rotr10': (15, True),
        'fliplr_rotl10': (-15, True),
    }

    def __init__(self, device=None, conf_th=None):
        config = {
            "superpoint": {
                "nms_radius": 3,
                "keypoint_threshold": 0.005,
                "max_keypoints": 2048,
            },
            "superglue": {
                "weights": "outdoor",
                "sinkhorn_iterations": 100,
                "match_threshold": 0.2,
            }
        }
        self.device = device
        self._superglue_matcher = SuperGlueCustomMatchingV2(
            config=config, device=self.device,
            ).eval().to(device)

        self.conf_thresh = conf_th

    def prep_np_img(self, img, long_side=None):
        """Optionally resize so the longer side is ``long_side``, then convert BGR to gray.

        Returns:
            (grayscale numpy image, scale factor)
        """
        if long_side is not None:
            scale = long_side / max(img.shape[0], img.shape[1])
            w = int(img.shape[1] * scale)
            h = int(img.shape[0] * scale)
            img = cv2.resize(img, (w, h))
        else:
            scale = 1.0
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), scale

    def frame2tensor(self, frame):
        """Convert a 2-D uint8 image to a (1, 1, H, W) float tensor in [0, 1] on ``self.device``."""
        return (torch.from_numpy(frame).float()/255.)[None, None].to(self.device)

    def tta_rotation_preprocess(self, img_np, angle):
        """Rotate a grayscale image by ``angle`` degrees about its centre (same canvas size).

        Returns:
            (forward 2x3 matrix, rotated tensor, inverse 2x3 matrix)
        """
        rot_M = cv2.getRotationMatrix2D((img_np.shape[1] / 2, img_np.shape[0] / 2), angle, 1)
        rot_M_inv = cv2.getRotationMatrix2D((img_np.shape[1] / 2, img_np.shape[0] / 2), -angle, 1)
        rot_img = self.frame2tensor(cv2.warpAffine(img_np, rot_M, (img_np.shape[1], img_np.shape[0])))
        return rot_M, rot_img, rot_M_inv

    def tta_rotation_postprocess(self, kpts, img_np, rot_M_inv):
        """Apply the 2x3 affine ``rot_M_inv`` to (N, 2) keypoints; also return an in-frame mask."""
        ones = np.ones(shape=(kpts.shape[0], ), dtype=np.float32)[:, None]
        hom = np.concatenate([kpts, ones], 1)
        rot_kpts = rot_M_inv.dot(hom.T).T[:, :2]
        mask = (rot_kpts[:, 0] >= 0) & (rot_kpts[:, 0] < img_np.shape[1]) & (rot_kpts[:, 1] >= 0) & (rot_kpts[:, 1] < img_np.shape[0])
        return rot_kpts, mask

    @staticmethod
    def _unique_ttas(tta_groups):
        """Every TTA name used in ``tta_groups``, once each, in first-seen order.

        The order is deterministic, unlike ``list(set(...))``, whose order
        depends on the string hash seed.
        """
        return list(dict.fromkeys(t for group in tta_groups for t in group))

    def _augment(self, tta_elem, img_np0, img_np1, img_ts0, img_ts1):
        """Return the (image0, image1) tensors for the TTA view ``tta_elem``.

        Raises:
            ValueError: for an unknown TTA name.
        """
        if tta_elem == 'orig':
            return img_ts0, img_ts1
        if tta_elem == 'flip_lr':
            return torch.flip(img_ts0, [3]), torch.flip(img_ts1, [3])
        if tta_elem == 'flip_ud':
            return torch.flip(img_ts0, [2]), torch.flip(img_ts1, [2])
        if tta_elem in self._ROTATION_TTA:
            angle, flip_first = self._ROTATION_TTA[tta_elem]
            src0 = img_np0[:, ::-1] if flip_first else img_np0
            src1 = img_np1[:, ::-1] if flip_first else img_np1
            return self.tta_rotation_preprocess(src0, angle)[1], self.tta_rotation_preprocess(src1, angle)[1]
        if tta_elem == 'eqhist':
            return self.frame2tensor(cv2.equalizeHist(img_np0)), self.frame2tensor(cv2.equalizeHist(img_np1))
        if tta_elem == 'clahe':
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
            return self.frame2tensor(clahe.apply(img_np0)), self.frame2tensor(clahe.apply(img_np1))
        raise ValueError('Unknown TTA method.')

    @staticmethod
    def _collect_matches(group_preds, conf_thresh):
        """Gather matched keypoint pairs from every group prediction.

        In each group, ``keypoints{0,1}`` is either a (1, N, 2) tensor or a
        one-element list of (N, 2) tensors. ``matches0`` and
        ``matching_scores0`` are (1, N). Explicit reshapes are used instead
        of ``squeeze()``, which collapsed groups holding a single keypoint.

        Returns:
            (mkpts0, mkpts1): (M, 2) numpy arrays concatenated over groups.
        """
        mkpts0, mkpts1 = [], []
        for group_pred in group_preds:
            kpts0 = group_pred['keypoints0'][0].detach().cpu().numpy().reshape(-1, 2)
            kpts1 = group_pred['keypoints1'][0].detach().cpu().numpy().reshape(-1, 2)
            matches = group_pred['matches0'][0].detach().cpu().numpy().reshape(-1)
            conf = group_pred['matching_scores0'][0].detach().cpu().numpy().reshape(-1)

            valid = matches > -1
            if conf_thresh is not None:
                valid &= conf >= conf_thresh
            mkpts0.append(kpts0[valid])
            mkpts1.append(kpts1[matches[valid]])
        if not mkpts0:
            return np.zeros((0, 2), np.float32), np.zeros((0, 2), np.float32)
        return np.concatenate(mkpts0), np.concatenate(mkpts1)

    @staticmethod
    def _inside(pts, shape):
        """Boolean mask of (N, 2) x/y points inside an image of ``shape`` (H, W, ...)."""
        return (pts[:, 0] >= 0) & (pts[:, 0] < shape[1]) & (pts[:, 1] >= 0) & (pts[:, 1] < shape[0])

    def __call__(self, img_np0, img_np1, tta_groups=(('orig',),), forward_type='cross', input_longside=None):
        """Match two BGR images with TTA.

        Args:
            img_np0, img_np1: BGR uint8 numpy images.
            tta_groups: groups of TTA names (see SuperGlueCustomMatchingV2's
                ``forward_cross`` / ``forward_flat`` for group semantics).
            forward_type: 'cross' or 'flat'.
            input_longside: optional long-side resize applied before matching.

        Returns:
            (mkpts0, mkpts1) in input-image pixels.

        Raises:
            ValueError: unknown TTA name. RuntimeError: unknown ``forward_type``.
        """
        with torch.no_grad():
            img_np0, scale0 = self.prep_np_img(img_np0, input_longside)
            img_np1, scale1 = self.prep_np_img(img_np1, input_longside)
            img_ts0 = self.frame2tensor(img_np0)
            img_ts1 = self.frame2tensor(img_np1)

            tta = self._unique_ttas(tta_groups)
            views = [self._augment(t, img_np0, img_np1, img_ts0, img_ts1) for t in tta]
            data = {
                "image0": torch.cat([v0 for v0, _ in views]),
                "image1": torch.cat([v1 for _, v1 in views]),
            }

            if forward_type == 'cross':
                pred = self._superglue_matcher.forward_cross(data=data, ttas=tta, tta_groups=tta_groups)
            elif forward_type == 'flat':
                pred = self._superglue_matcher.forward_flat(data=data, ttas=tta, tta_groups=tta_groups)
            else:
                raise RuntimeError(f'Unknown forward_type {forward_type}')

            cat_mkpts0, cat_mkpts1 = self._collect_matches(pred, self.conf_thresh)
            # Rotated views can map keypoints outside the (resized) image.
            inside = self._inside(cat_mkpts0, img_np0.shape) & self._inside(cat_mkpts1, img_np1.shape)
            return cat_mkpts0[inside] / scale0, cat_mkpts1[inside] / scale1


# 1600 is the validation size in the paper.
# NOTE(review): no resize is configured here; the long side is chosen per
# call via `input_longside` (default None = no resize). Matches with a
# SuperGlue matching score below 0.4 are discarded.
superglue_matcher = SuperGlueMatcherV2(device=device, conf_th=0.4)


## DKM

In [None]:
# Install DKM and move checkpoints to the local checkpoints dir.
# Steps:
#   1. copy dkm.pth to pretrained/checkpoints/dkm_base_v11.pth (renamed,
#      presumably to the file name the DKM loader looks up; verify);
#   2. install einops offline from the dependency wheels;
#   3. copy the DKM sources to /kaggle/working/DKM and pip-install them in
#      editable mode.
# Set this to True to redo the install even if INSTALLED_LOG says it was done.
force_dkm_reinstall = False

if 'DKM' not in INSTALLED_LOG or force_dkm_reinstall:
    !mkdir -p pretrained/checkpoints
    !cp /kaggle/input/imc2022-dependencies/pretrained/dkm.pth pretrained/checkpoints/dkm_base_v11.pth

    !pip install -f /kaggle/input/imc2022-dependencies/wheels --no-index einops
    !cp -r /kaggle/input/imc2022-dependencies/DKM/ /kaggle/working/DKM/
    !cd /kaggle/working/DKM/; pip install -f /kaggle/input/imc2022-dependencies/wheels -e .
    INSTALLED_LOG['DKM'] = True
else:
    print('Already installed DKM. Set "force_dkm_reinstall=True" to override this behavior.')

# imports for DKM
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import sys, os, csv
from PIL import Image
import cv2, gc
import matplotlib.pyplot as plt
import torch
import types
import torch.nn.functional as TorchFunc
sys.path.append('/kaggle/input/imc2022-dependencies/DKM/')

# NOTE(review): this flag is assigned but not read by any code visible in this section.
dry_run = False


from dkm import dkm_base
from torchvision import transforms


def custom_stable_neighbours(self, query_coords, query_to_support, support_to_query,
                             num_iters=4, cycle_tol=1e-3, coord_tol=5e-3):
    """Cycle a query->support / support->query flow pair and flag stable pixels.

    ``self`` is unused; the function is presumably attached to a DKM model
    instance (e.g. via ``types.MethodType``) to replace its method.

    Args:
        query_coords: reference coordinates compared with the cycled query
            coordinates along dim 1, presumably (B, 2, H, W) normalised
            grid coordinates (TODO confirm).
        query_to_support: (B, H, W, 2) sampling grid, used directly as the
            grid of ``F.grid_sample``.
        support_to_query: (B, 2, Hs, Ws) field sampled by that grid.
        num_iters: number of forward/backward cycles (must be >= 1).
        cycle_tol: max change of the query->support grid in the last cycle.
        coord_tol: max distance between cycled and reference query coords.

    Returns:
        (q, qts, stabneigh): the cycled query coords (B, 2, H, W), the
        cycled query->support grid (B, H, W, 2), and a (B, H, W) boolean
        mask of pixels that pass both tolerances.

    Raises:
        ValueError: if ``num_iters < 1``.
    """
    if num_iters < 1:
        raise ValueError(f'num_iters must be >= 1, got {num_iters}')
    qts = query_to_support
    # Loop-invariant: the query->support field laid out as a grid_sample input.
    qts_field = query_to_support.permute(0, 3, 1, 2)
    for _ in range(num_iters):
        prev_qts = qts
        # align_corners=False is torch's default; stating it silences the
        # UserWarning without changing results.
        q = TorchFunc.grid_sample(support_to_query, qts, mode="bilinear", align_corners=False)
        qts = TorchFunc.grid_sample(
            qts_field,
            q.permute(0, 2, 3, 1),
            mode="bilinear",
            align_corners=False,
        ).permute(0, 2, 3, 1)
    d = (qts - prev_qts).norm(dim=-1)
    qd = (q - query_coords).norm(dim=1)
    stabneigh = torch.logical_and(d < cycle_tol, qd < coord_tol)
    return q, qts, stabneigh


def custom_match(
        self,
        im1,
        im2,
        batched=False,
        check_cycle_consistency=False,
        do_pred_in_og_res=False,
    ):
    """Densely match two images with a DKM model; bound onto the model as a method.

    Args:
        im1, im2: input images, in one of two forms.
            - ``batched`` False: objects exposing ``.size`` as ``(w, h)``
              (e.g. PIL images). They are resized and normalized here.
            - ``batched`` True: preprocessed tensors of shape (B, C, H, W)
              with equal H and W.
        batched: select the tensor-batch input path.
        check_cycle_consistency: run DKM symmetrically, replace the query
            coordinates with their cycled positions, and scale certainty by
            (stability mask + 1e-3).
        do_pred_in_og_res: upsample predictions to the input resolution via
            ``self.decoder.upsample_preds``. Assumes no batching.

    Returns:
        ``(matches, certainty)``.
            matches: ``matches[..., :2]`` is the normalized (x, y) query
                coordinate and ``matches[..., 2:]`` its [-1, 1]-clamped
                position in the support image.
            certainty: sigmoid probabilities.
        Shapes are (B, h, w, 4) and (B, h, w) when ``batched``, otherwise
        (h, w, 4) and (h, w).

    Raises:
        ValueError: if batched inputs differ in spatial size.
    """
    self.train(False)
    # Follow the model's device instead of the previously hard-coded
    # ".cuda()" / "cuda:0", so inputs and grids always match the weights.
    device = next(self.parameters()).device
    with torch.no_grad():
        if not batched:
            b = 1
            w, h = im1.size
            w2, h2 = im2.size
            # Get images in good format
            ws = self.w_resized
            hs = self.h_resized
            test_transform = get_tuple_transform_ops(
                resize=(hs, ws), normalize=True
            )
            query, support = test_transform((im1, im2))
            batch = {"query": query[None].to(device), "support": support[None].to(device)}
        else:
            b, c, h, w = im1.shape
            b, c, h2, w2 = im2.shape
            if w != w2 or h != h2:
                raise ValueError(
                    f"im1 and im2 must share a spatial size, got {(h, w)} and {(h2, w2)}"
                )
            batch = {"query": im1.to(device), "support": im2.to(device)}
            hs, ws = self.h_resized, self.w_resized
        # Assumes DKM refines down to scale 1 (min(dense_corresps) would also work).
        finest_scale = 1
        if check_cycle_consistency:
            dense_corresps = self.forward_symmetric(batch)
            query_to_support, support_to_query = dense_corresps[finest_scale][
                "dense_flow"
            ].chunk(2)
            query_to_support = query_to_support.permute(0, 2, 3, 1)
            # TODO: the reverse certainty (second chunk) is currently unused.
            dense_certainty, _ = dense_corresps[finest_scale]["dense_certainty"].chunk(2)
        else:
            dense_corresps = self.forward(batch)
            query_to_support = dense_corresps[finest_scale]["dense_flow"].permute(
                0, 2, 3, 1
            )
            dense_certainty = dense_corresps[finest_scale]["dense_certainty"]

        if do_pred_in_og_res:  # Assumes no batching.
            og_query, og_support = self.og_transforms((im1, im2))
            query_to_support, dense_certainty = self.decoder.upsample_preds(
                query_to_support,
                dense_certainty,
                og_query.to(device)[None],
                og_support.to(device)[None],
            )
            hs, ws = h, w
        # Grid of normalized pixel centres, stacked so channel 0 = x and channel 1 = y.
        query_coords = torch.meshgrid(
            (
                torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
                torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
            )
        )
        query_coords = torch.stack((query_coords[1], query_coords[0]))
        query_coords = query_coords[None].expand(b, 2, hs, ws)
        dense_certainty = dense_certainty.sigmoid()  # logits -> probs
        if check_cycle_consistency:
            query_coords, query_to_support, stabneigh = self.custom_stable_neighbours(
                query_coords, query_to_support, support_to_query
            )
            # Scale stable pixels by 1.001 and unstable ones by 1e-3 (not zeroed).
            dense_certainty *= stabneigh[:, None, :, :].float() + 1e-3
        query_coords = query_coords.permute(0, 2, 3, 1)
        query_to_support = torch.clamp(query_to_support, -1, 1)
        matches = torch.cat((query_coords, query_to_support), dim=-1)
        if batched:
            return matches, dense_certainty[:, 0]
        return matches[0], dense_certainty[0, 0]


class DKMMatcher:
    """Sparse keypoint matcher built on a pretrained DKM dense matcher, with TTA.

    The pipeline for one image pair:
        1. Resize every TTA variant of the pair to ``(config['w'], config['h'])``.
        2. Run all variants through DKM as a single batch.
        3. Sparsify the dense matches with ``results_thresholding``.
        4. Undo each TTA transform, so the returned keypoints are in the pixel
           frame of the original input images.
    """

    _DEFAULT_CONFIG = {
        'w': 512, 'h': 384,
        # __call__ reads this key; without it the default config raised KeyError.
        'check_cycle_consistency': False,
        'thresh': {
            'method': 'abs_rnd',
            'th': 0.9,
            'take': 100,
        }
    }
    # Rotation TTA names -> angle in degrees for cv2.getRotationMatrix2D.
    _TTA_ROTATIONS = {'rot_r10': 10, 'rot_l10': -10}

    def __init__(self, device=None, config=None):
        """Load pretrained DKM v11 and attach the custom matching helpers.

        Args:
            device: torch device the network is moved to.
            config: dict with keys 'w', 'h' and 'thresh', plus an optional
                'check_cycle_consistency'. Defaults to ``_DEFAULT_CONFIG``.
        """
        if config is None:
            config = self._DEFAULT_CONFIG
        torch.hub.set_dir('/kaggle/working/pretrained/')
        self._dkm_matcher = dkm_base(pretrained=True, version="v11").to(device).eval()
        self.config = config
        self._dkm_matcher.w_resized = config['w']
        self._dkm_matcher.h_resized = config['h']
        self._dkm_matcher.custom_stable_neighbours = types.MethodType(custom_stable_neighbours, self._dkm_matcher)
        self._dkm_matcher.custom_match = types.MethodType(custom_match, self._dkm_matcher)
        self.device = device

        # Standard ImageNet mean/std, applied to RGB inputs scaled to [0, 1].
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        self.normalize = transforms.Normalize(mean=mean, std=std)

    def prepare_torch_image(self, img):
        """Convert a BGR image array into a normalized RGB float tensor.

        The result presumably has shape (1, 3, H, W), since __call__
        concatenates these tensors along dim 0.
        """
        img_ts = K.image_to_tensor(img, False).float() / 255.
        img_ts = K.color.bgr_to_rgb(img_ts)
        return self.normalize(img_ts)

    def results_thresholding(self, dense_matches, dense_certainty):
        """Sparsify the dense DKM output for one image pair.

        Args:
            dense_matches: (H, W, 4) tensor of [x1, y1, x2, y2] in [-1, 1].
            dense_certainty: (H, W) tensor of per-pixel certainties.

        Returns:
            Sampled matches, one ``[x1, y1, x2, y2]`` row each. For 'abs' this
            is a NumPy (K, 4) array. For the DKM-sampled methods it is whatever
            ``sample`` returns (presumably the same layout — verify).

        Methods (``config['thresh']['method']``):
            'abs_rnd': count pixels with certainty >= th, cap the count at
                'take', and let DKM's ``sample`` draw that many matches.
            'rel': DKM's ``sample`` with 'take' draws, using th as the
                relative confidence threshold.
            'abs': keep pixels with certainty > th, then draw up to 'take' of
                them without replacement, weighted by certainty.

        Raises:
            ValueError: on an unknown method.
        """
        thresh_cfg = self.config['thresh']
        method = thresh_cfg['method']
        if method == 'abs_rnd':
            n_take = (dense_certainty >= thresh_cfg['th']).count_nonzero().cpu().numpy()
            n_take = min(max(n_take, 0), thresh_cfg['take'])
            sparse_matches, _ = self._dkm_matcher.sample(dense_matches, dense_certainty, num=n_take)
        elif method == 'rel':
            sparse_matches, _ = self._dkm_matcher.sample(
                dense_matches, dense_certainty,
                num=thresh_cfg['take'], relative_confidence_threshold=thresh_cfg['th'])
        elif method == 'abs':
            matches = dense_matches.reshape(-1, 4)
            certainty = dense_certainty.reshape(-1)
            keep = certainty > thresh_cfg['th']
            th_matches = matches[keep].cpu().numpy()
            th_certainty = certainty[keep].cpu().numpy()
            if len(th_matches) > 0:
                # Certainty-weighted draw without replacement; the epsilon only
                # guards the normalization.
                good_samples = np.random.choice(
                    np.arange(len(th_matches)),
                    size=min(thresh_cfg['take'], len(th_certainty)),
                    replace=False,
                    p=th_certainty / (np.sum(th_certainty) + 1e-6),
                )
                sparse_matches = th_matches[good_samples]
            else:
                sparse_matches = th_matches
        else:
            raise ValueError('Unknown thresholding method ' + str(method))
        return sparse_matches

    def tta_rotation_preprocess(self, img_np, angle):
        """Rotate ``img_np`` by ``angle`` degrees about its centre (same output size).

        Returns:
            ``(rot_M, rotated_image, rot_M_inv)``, where ``rot_M_inv`` rotates
            by ``-angle`` about the same centre.
        """
        center = (img_np.shape[1] / 2, img_np.shape[0] / 2)
        rot_M = cv2.getRotationMatrix2D(center, angle, 1)
        rot_M_inv = cv2.getRotationMatrix2D(center, -angle, 1)
        rot_img = cv2.warpAffine(img_np, rot_M, (img_np.shape[1], img_np.shape[0]))
        return rot_M, rot_img, rot_M_inv

    def tta_rotation_postprocess(self, kpts, img_np, rot_M_inv):
        """Map (N, 2) keypoints through the 2x3 affine ``rot_M_inv``.

        Returns:
            ``(mapped_kpts, mask)``, where ``mask`` marks the points that land
            inside ``img_np``.
        """
        ones = np.ones(shape=(kpts.shape[0], ), dtype=np.float32)[:, None]
        hom = np.concatenate([kpts, ones], 1)
        rot_kpts = rot_M_inv.dot(hom.T).T[:, :2]
        mask = (rot_kpts[:, 0] >= 0) & (rot_kpts[:, 0] < img_np.shape[1]) & (rot_kpts[:, 1] >= 0) & (rot_kpts[:, 1] < img_np.shape[0])
        return rot_kpts, mask

    def __call__(self, img_bgr_np1, img_bgr_np2, tta=('orig',), verbose=0):
        """Match two BGR images and return keypoints in original pixel coordinates.

        Args:
            img_bgr_np1, img_bgr_np2: HxWx3 BGR image arrays (e.g. from cv2.imread).
            tta: sequence of TTA names from 'orig', 'flip_lr', 'flip_ud',
                'rot_r10' and 'rot_l10'. Results of all variants are concatenated.
            verbose: if >= 1, print timings of the internal stages.

        Returns:
            ``(mkps1, mkps2)``: matched (x, y) pixel coordinates in image 1 and image 2.

        Raises:
            ValueError: if ``tta`` is empty or contains an unknown name.
        """
        import time  # the notebook only imports `time` in a later cell

        if len(tta) == 0:
            raise ValueError('At least one TTA method is required.')
        w_res = self._dkm_matcher.w_resized
        h_res = self._dkm_matcher.h_resized

        with torch.no_grad():
            # --- TTA preparation: flip/rotate at full resolution, then resize.
            # Rotating before resizing avoids losing detail at network resolution.
            ttaprep_st = time.time()
            images0, images1 = [], []
            rot_inv = {}  # rotation TTA name -> (inverse matrix for image 1, for image 2)
            for tta_elem in tta:
                if tta_elem == 'orig':
                    aug0, aug1 = img_bgr_np1, img_bgr_np2
                elif tta_elem == 'flip_lr':
                    aug0 = np.flip(img_bgr_np1, [1, ]).copy()
                    aug1 = np.flip(img_bgr_np2, [1, ]).copy()
                elif tta_elem == 'flip_ud':
                    aug0 = np.flip(img_bgr_np1, [0, ]).copy()
                    aug1 = np.flip(img_bgr_np2, [0, ]).copy()
                elif tta_elem in self._TTA_ROTATIONS:
                    angle = self._TTA_ROTATIONS[tta_elem]
                    _, aug0, inv0 = self.tta_rotation_preprocess(img_bgr_np1, angle)
                    _, aug1, inv1 = self.tta_rotation_preprocess(img_bgr_np2, angle)
                    rot_inv[tta_elem] = (inv0, inv1)
                else:
                    raise ValueError('Unknown TTA method.')
                aug0 = cv2.resize(aug0, (w_res, h_res))
                aug1 = cv2.resize(aug1, (w_res, h_res))
                images0.append(self.prepare_torch_image(aug0))
                images1.append(self.prepare_torch_image(aug1))
            ttaprep_nd = time.time()

            # --- Batched DKM inference over all TTA variants at once.
            batchinf_st = time.time()
            dense_matches, dense_certainty = self._dkm_matcher.custom_match(
                torch.cat(images0), torch.cat(images1), batched=True,
                check_cycle_consistency=self.config.get('check_cycle_consistency', False))
            batchinf_nd = time.time()

            # SQRT of dense certainty (sharpens nothing; raises all values toward 1).
            dense_certainty = dense_certainty.sqrt()

            # --- Sparsify, then map normalized [-1, 1] coords to original pixels.
            getkpts_st = time.time()
            h1, w1 = img_bgr_np1.shape[:2]
            h2, w2 = img_bgr_np2.shape[:2]
            mkps1, mkps2 = [], []
            for matches_i, certainty_i in zip(dense_matches, dense_certainty):
                sparse = self.results_thresholding(matches_i, certainty_i)
                pts1, pts2 = sparse[:, :2], sparse[:, 2:]
                pts1[:, 0] = ((pts1[:, 0] + 1.) / 2.) * w1
                pts1[:, 1] = ((pts1[:, 1] + 1.) / 2.) * h1
                pts2[:, 0] = ((pts2[:, 0] + 1.) / 2.) * w2
                pts2[:, 1] = ((pts2[:, 1] + 1.) / 2.) * h2
                mkps1.append(pts1)
                mkps2.append(pts2)
            getkpts_nd = time.time()

            # --- Undo each TTA transform ('orig' needs nothing; unknown names
            # were already rejected above).
            revtta_st = time.time()
            for idx, tta_elem in enumerate(tta):
                if tta_elem == 'flip_lr':
                    mkps1[idx][:, 0] = w1 - mkps1[idx][:, 0]
                    mkps2[idx][:, 0] = w2 - mkps2[idx][:, 0]
                elif tta_elem == 'flip_ud':
                    mkps1[idx][:, 1] = h1 - mkps1[idx][:, 1]
                    mkps2[idx][:, 1] = h2 - mkps2[idx][:, 1]
                elif tta_elem in rot_inv:
                    inv0, inv1 = rot_inv[tta_elem]
                    pts1, in_img1 = self.tta_rotation_postprocess(mkps1[idx], img_bgr_np1, inv0)
                    pts2, in_img2 = self.tta_rotation_postprocess(mkps2[idx], img_bgr_np2, inv1)
                    keep = in_img1 & in_img2
                    mkps1[idx], mkps2[idx] = pts1[keep], pts2[keep]
            revtta_nd = time.time()

            if verbose >= 1:
                print('    - DKM inner:')
                print('      - ttaprep:', ttaprep_nd - ttaprep_st)
                print('      - batchinf:', batchinf_nd - batchinf_st)
                print('      - getkpts:', getkpts_nd - getkpts_st)
                print('      - revtta:', revtta_nd - revtta_st)
            return np.concatenate(mkps1), np.concatenate(mkps2)


# abs_rnd_config = {
#     'w': 512, 'h': 384,
#     'thresh': {'method': 'abs', 'th': 0.9, 'take': 100}
# }
abs_config = {
    'w': 800, 'h': 600, 'check_cycle_consistency': False,
    'thresh': {'method': 'abs', 'th': 0.8, 'take': 200}
} 
# rel_config = {
#     'w': 512, 'h': 384,
#     'thresh': {'method': 'rel', 'th': 0.95, 'take': 69}
# }
# abs_sqrt_config = {  # Abs sqrt is actualy abs_pow(1/4) because the dense_confidence score is already sqrt-processed, so double-sqrt
#     'w': 800, 'h': 600,
#     'thresh': {'method': 'abs_sqrt', 'th': 0.8, 'take': 250}
# }

dkm_matcher = DKMMatcher(device=device, config=abs_config)

# DISK

In [None]:
%%writefile /tmp/disk_install.txt

# Install DISK
force_disk_reinstall = False

if not INSTALLED_LOG.get('disk_installed', False) or force_disk_reinstall:
    # !pip install ../input/disk-deps/disk_repo/submodules/torch-localize --no-index --find-links ../input/disk-deps/disk_repo/wheel_files
    # !pip install ../input/disk-deps/disk_repo/submodules/torch-dimcheck --no-index --find-links ../input/disk-deps/disk_repo/wheel_files
    # !pip install ../input/disk-deps/disk_repo/submodules/unets --no-index --find-links ../input/disk-deps/disk_repo/wheel_files
    !pip install -f /kaggle/input/k/eduardtrulls/imc2022-dependencies/wheels --no-index torch_dimcheck
    !pip install -f /kaggle/input/k/eduardtrulls/imc2022-dependencies/wheels --no-index torch_localize
    !pip install -f /kaggle/input/k/eduardtrulls/imc2022-dependencies/wheels --no-index unets
    !pip install -f /kaggle/input/k/eduardtrulls/imc2022-dependencies/wheels --no-index disk
else:
    print('Already installed DISK. Set "force_disk_reinstall=True" to override this behavior.')

In [None]:
%%writefile /tmp/disk_usage.txt

import disk
from disk import DISK, Features


from torch_dimcheck import dimchecked
from functools import partial
from disk.geom import distance_matrix
import math


# Directory holding the DISK checkpoint; DiskWrapper loads 'depth-save.pth' from here.
DISK_REPO_PATH = '../input/disk-deps/disk_repo/'


class DiskImage:
    """A (C, H, W) image tensor plus the shape needed to map DISK keypoints back."""

    def __init__(self, bitmap: ['C', 'H', 'W'], fname: str, orig_shape=None):
        self.bitmap = bitmap
        self.fname = fname
        # (H, W) of the frame that keypoints are mapped back to.
        self.orig_shape = self.bitmap.shape[1:] if orig_shape is None else orig_shape

    def resize_to(self, shape):
        """Return a copy scaled to fit ``shape`` (aspect kept) and zero-padded to it."""
        return DiskImage(
            self._pad(self._interpolate(self.bitmap, shape), shape),
            self.fname,
            orig_shape=self.bitmap.shape[1:],
        )

    @dimchecked
    def to_image_coord(self, xys: [2, 'N']) -> ([2, 'N'], ['N']):
        """Map (2, N) keypoints (x, y) from the bitmap frame to ``orig_shape``.

        Returns:
            ``(scaled, mask)``: the rescaled coordinates, and a mask of the
            points inside the original image.
        """
        f, _size = self._compute_interpolation_size(self.bitmap.shape[1:])
        scaled = xys / f

        h, w = self.orig_shape
        x, y = scaled

        mask = (0 <= x) & (x < w) & (0 <= y) & (y < h)

        return scaled, mask

    def _compute_interpolation_size(self, shape):
        """Return ``(f, (h, w))``: the scale factor and the size that fit ``orig_shape`` into ``shape``."""
        h_factor = self.orig_shape[0] / shape[0]
        w_factor = self.orig_shape[1] / shape[1]

        f = 1 / max(h_factor, w_factor)

        if h_factor > w_factor:
            new_size = (shape[0], int(f * self.orig_shape[1]))
        else:
            new_size = (int(f * self.orig_shape[0]), shape[1])

        return f, new_size

    @dimchecked
    def _interpolate(self, image: ['C', 'H', 'W'], shape) -> ['C', 'h', 'w']:
        """Bilinearly resize ``image`` to the aspect-preserving size fitting ``shape``."""
        _f, size = self._compute_interpolation_size(shape)
        # Use torch.nn.functional explicitly: this cell never imports an `F` alias.
        return torch.nn.functional.interpolate(
            image.unsqueeze(0),
            size=size,
            mode='bilinear',
            align_corners=False,
        ).squeeze(0)

    @dimchecked
    def _pad(self, image: ['C', 'H', 'W'], shape) -> ['C', 'h', 'w']:
        """Zero-pad ``image`` on the bottom/right up to ``shape`` (H, W)."""
        h_pad = shape[0] - image.shape[1]
        w_pad = shape[1] - image.shape[2]

        if h_pad < 0 or w_pad < 0:
            raise ValueError("Attempting to pad by negative value")

        # Pad tuple order is (left, right, top, bottom) for the last two dims.
        return torch.nn.functional.pad(image, (0, w_pad, 0, h_pad))


class DiskMatcher:
    """Mutual-nearest-neighbour matcher with a Lowe-style ratio test.

    Distance matrices larger than ``MAX_FULL_MATRIX`` elements are computed in
    row chunks to bound memory.
    """

    MAX_FULL_MATRIX = 10000**2
    _help=('this is the biggest match matrix that will attempt to be '
           'computed allocated in memory. Matrices bigger than that will '
           'be split into chunks of at most this size. Reduce if your '
           'script runs out of memory.')

    @staticmethod
    @dimchecked
    def _binary_to_index(binary_mask: ['N'], ix2: ['M']) -> [2, 'M']:
        """Pair the indices of True entries in ``binary_mask`` with ``ix2`` as a (2, M) tensor."""
        return torch.stack([
            torch.nonzero(binary_mask, as_tuple=False)[:, 0],
            ix2
        ], dim=0)

    @staticmethod
    @dimchecked
    def _ratio_one_way(dist_m: ['N', 'M'], rt) -> [2, 'K']:
        """Return the (row, col) pairs whose nearest/second-nearest distance ratio is below ``rt``."""
        val, ix = torch.topk(dist_m, k=2, dim=1, largest=False)
        ratio = val[:, 0] / val[:, 1]
        passed_test = ratio < rt
        ix2 = ix[passed_test, 0]

        return DiskMatcher._binary_to_index(passed_test, ix2)

    @staticmethod
    @dimchecked
    def _match_chunkwise(ds1: ['N', 'F'], ds2: ['M', 'F'], rt) -> [2, 'K']:
        """One-way ratio-test matches from ds1 to ds2, computed in row chunks of ds1.

        Chunking over ds1 rows (not ds2 columns) keeps the ratio test of every
        row computed against ALL of ds2. The old column chunking produced
        per-chunk "nearest neighbours" once N*M exceeded MAX_FULL_MATRIX.
        """
        n_rows, n_cols = ds1.shape[0], ds2.shape[0]
        empty = torch.empty((2, 0), dtype=torch.long, device=ds1.device)
        if n_cols < 2:
            # The ratio test needs two neighbours; nothing can pass it.
            return empty
        chunk_size = max(1, DiskMatcher.MAX_FULL_MATRIX // n_cols)
        matches = []
        for start in range(0, n_rows, chunk_size):
            dist_m = distance_matrix(ds1[start:start + chunk_size], ds2)
            one_way = DiskMatcher._ratio_one_way(dist_m, rt)
            one_way[0] += start
            matches.append(one_way)

        if not matches:  # ds1 is empty
            return empty
        return torch.cat(matches, dim=1)

    @staticmethod
    @dimchecked
    def _match(ds1: ['N', 'F'], ds2: ['M', 'F'], rt) -> [2, 'K']:
        """Return (2, K) index pairs that pass the ratio test in both directions."""
        fwd = DiskMatcher._match_chunkwise(ds1, ds2, rt)
        bck = torch.flip(DiskMatcher._match_chunkwise(ds2, ds1, rt), (0, ))

        merged = torch.cat([fwd, bck], dim=1)
        if merged.shape[1] == 0:
            return merged
        unique, counts = torch.unique(merged, dim=1, return_counts=True)

        # Each direction lists a pair at most once, so count 2 means mutual.
        return unique[:, counts == 2]

    @staticmethod
    def match(desc_1, desc_2, rt=1., u16=False):
        """Match two descriptor sets and return a NumPy (2, K) index array."""
        matched_pairs = DiskMatcher._match(desc_1, desc_2, rt)
        matches = matched_pairs.cpu().numpy()

        if u16:
            matches = matches.astype(np.uint16)

        return matches


class DiskWrapper:
    """DISK keypoint/descriptor extractor plus matcher, with flip TTA.

    The pipeline for one image pair:
        1. Resize each image so its longer side is ``long_side``, then round
           both sides up to multiples of 16.
        2. Extract features for all TTA variants in one batch.
        3. Match each variant separately.
        4. Map the keypoints back to the pixels of the original images.
    """

    def __init__(self, n=2048, matcher='cv2_bf', long_side=1024, device=device):
        """Load the DISK checkpoint and set up the requested matcher.

        Args:
            n: keypoints kept per image by DISK's NMS extraction (None = unlimited).
            matcher: either of:
                'cv2_bf': OpenCV brute-force L2 matching with cross-check.
                'disk': DiskMatcher's mutual ratio-test matching.
            long_side: target length of the longer image side; None keeps the size.
            device: torch device for the model and inputs.

        Raises:
            KeyError: if the checkpoint has neither 'extractor' nor 'disk' weights.
            RuntimeError: on an unknown matcher name.
        """
        state_dict = torch.load(os.path.join(DISK_REPO_PATH, 'depth-save.pth'), map_location='cpu')

        # compatibility with older model saves which used the 'extractor' name
        if 'extractor' in state_dict:
            weights = state_dict['extractor']
        elif 'disk' in state_dict:
            weights = state_dict['disk']
        else:
            raise KeyError('Incompatible weight file!')

        self.model = DISK(window=8, desc_dim=128)
        self.model.load_state_dict(weights)
        self.model = self.model.to(device)

        self._extract = partial(
            self.model.features,
            kind='nms',
            window_size=3,  # NMS window size
            cutoff=0.,
            n=n,  # None means unlimited keypoints
        )

        self.matcher = matcher
        if self.matcher == 'cv2_bf':
            # Brute-force matcher with bi-directional check.
            self._cv2_match_fn = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
        elif self.matcher == 'disk':
            # Match using the same algo as in DISK repo
            self._disk_match_fn = partial(DiskMatcher.match, rt=1., u16=False)
        else:
            raise RuntimeError(f'Unknown matcher {self.matcher}')

        self.long_side = long_side
        self.device = device

    def prep_img(self, img, long_side=1024):
        """Resize a BGR image for DISK and convert it to a batched RGB float tensor.

        Returns:
            ``(tensor, scalew, scaleh)``: the tensor on ``self.device``, and
            the factors that map original pixel coordinates to resized ones.
        """
        # Resize so that the longest side is "long_side"
        if long_side is not None:
            scale = long_side / max(img.shape[0], img.shape[1])
            w = int(img.shape[1] * scale)
            h = int(img.shape[0] * scale)
        else:
            w, h = img.shape[1], img.shape[0]

        # Round sides up to multiples of 16 (a no-op when already divisible).
        w = int(math.ceil(w / 16.) * 16)
        h = int(math.ceil(h / 16.) * 16)

        scalew, scaleh = float(w) / img.shape[1], float(h) / img.shape[0]
        img = cv2.resize(img, (w, h))
        img_ts = K.image_to_tensor(img, False).float() / 255.
        img_ts = K.color.bgr_to_rgb(img_ts)
        return img_ts.to(self.device), scalew, scaleh

    def extract_features(self, bitmap_batch):
        """Run DISK on a batch of images.

        Returns:
            Per-image lists of keypoints, descriptors and scores, restricted to
            points inside the image and sorted by descending score.
        """
        images = []
        for i, bitmap in enumerate(bitmap_batch):
            images.append(DiskImage(bitmap, f'img{i}'))

        with torch.no_grad():
            batched_features = self._extract(bitmap_batch)

            ret_keypoints, ret_descriptors, ret_scores = [], [], []
            for features, image in zip(batched_features.flat, images):
                kps_crop_space = features.kp.T
                kps_img_space, mask = image.to_image_coord(kps_crop_space)

                keypoints   = kps_img_space.T[mask]
                descriptors = features.desc[mask]
                scores      = features.kp_logp[mask]

                order = torch.flip(torch.argsort(scores), [0, ])

                keypoints   = keypoints[order]
                descriptors = descriptors[order]
                scores      = scores[order]

                ret_keypoints.append(keypoints)
                ret_descriptors.append(descriptors)
                ret_scores.append(scores)

        return ret_keypoints, ret_descriptors, ret_scores

    def __call__(self, img_bgr_np0, img_bgr_np1, tta=('orig', )):
        """Match two BGR images.

        Args:
            img_bgr_np0, img_bgr_np1: BGR image arrays.
            tta: TTA names from 'orig', 'flip_lr' and 'flip_ud'.

        Returns:
            ``(mkpts0, mkpts1)``: NumPy arrays of matched keypoints in original
            pixel coordinates, concatenated over all TTA variants.
        """
        with torch.no_grad():
            # Honour the constructor's long_side (previously always 1024).
            img_rgb_ts0, scale0_w, scale0_h = self.prep_img(img_bgr_np0, long_side=self.long_side)
            img_rgb_ts1, scale1_w, scale1_h = self.prep_img(img_bgr_np1, long_side=self.long_side)

            # TTA
            images0, images1 = [], []
            for tta_elem in tta:
                if tta_elem == 'orig':
                    img_ts0_aug, img_ts1_aug = img_rgb_ts0, img_rgb_ts1
                elif tta_elem == 'flip_lr':
                    img_ts0_aug = torch.flip(img_rgb_ts0, [3, ])
                    img_ts1_aug = torch.flip(img_rgb_ts1, [3, ])
                elif tta_elem == 'flip_ud':
                    img_ts0_aug = torch.flip(img_rgb_ts0, [2, ])
                    img_ts1_aug = torch.flip(img_rgb_ts1, [2, ])
                else:
                    raise ValueError('Unknown TTA method.')
                images0.append(img_ts0_aug)
                images1.append(img_ts1_aug)

            # Batched inference
            keypoints0, descriptors0, scores0 = self.extract_features(torch.cat(images0))
            keypoints1, descriptors1, scores1 = self.extract_features(torch.cat(images1))

            # Match and post-process
            mkpts0, mkpts1 = [], []
            for idx, tta_elem in enumerate(tta):
                kpts0, kpts1 = keypoints0[idx], keypoints1[idx]
                desc0, desc1 = descriptors0[idx], descriptors1[idx]

                # Re-scale keypoints to the size of original image
                kpts0[:, 0], kpts0[:, 1] = kpts0[:, 0] / scale0_w, kpts0[:, 1] / scale0_h
                kpts1[:, 0], kpts1[:, 1] = kpts1[:, 0] / scale1_w, kpts1[:, 1] / scale1_h

                # Reverse TTA on keypoints
                if tta_elem == 'orig':
                    pass
                elif tta_elem == 'flip_lr':
                    kpts0[:, 0] = img_bgr_np0.shape[1] - kpts0[:, 0]
                    kpts1[:, 0] = img_bgr_np1.shape[1] - kpts1[:, 0]
                elif tta_elem == 'flip_ud':
                    kpts0[:, 1] = img_bgr_np0.shape[0] - kpts0[:, 1]
                    kpts1[:, 1] = img_bgr_np1.shape[0] - kpts1[:, 1]
                else:
                    raise ValueError('Unknown TTA method.')

                # Match keypoints
                if self.matcher == 'cv2_bf':
                    cv_matches = self._cv2_match_fn.match(
                        desc0.cpu().numpy(),
                        desc1.cpu().numpy())
                    # An integer (K, 2) array even when K == 0; an empty
                    # np.array([]) is 1-D float and would break the indexing.
                    matches = np.array(
                        [[m.queryIdx, m.trainIdx] for m in cv_matches], dtype=np.int64
                    ).reshape(-1, 2)
                    mkpts0.append(kpts0.cpu().numpy()[matches[:, 0]])
                    mkpts1.append(kpts1.cpu().numpy()[matches[:, 1]])
                elif self.matcher == 'disk':
                    disk_matches = self._disk_match_fn(desc0, desc1)
                    mkpts0.append(kpts0[disk_matches[0, :]].cpu().numpy())
                    mkpts1.append(kpts1[disk_matches[1, :]].cpu().numpy())
                else:
                    raise RuntimeError(f'Unknown matcher {self.matcher}')

            # Return keypoints
            return np.concatenate(mkpts0), np.concatenate(mkpts1)

# Module-level DISK matcher (mutual ratio-test matching, n=2048 keypoints per image).
# NOTE(review): the DISK entry in match_images' matcher list is commented out,
# so this instance is not used there.
disk_matcher = DiskWrapper(matcher='disk', n=2048)

# Utils

In [None]:
src = '/kaggle/input/image-matching-challenge-2022/'

# Load every test pair; rows are [sample_id, batch_id, image_1_id, image_2_id].
with open(f'{src}/test.csv') as f:
    reader = csv.reader(f, delimiter=',')
    next(reader, None)  # drop the header line
    test_samples = list(reader)


def FlattenMatrix(M, num_digits=8):
    """Serialize ``M`` row-major as space-separated scientific-notation numbers.

    Used when writing the fundamental matrices to the submission CSV.
    """
    spec = f'.{num_digits}e'
    return ' '.join(format(value, spec) for value in M.flatten())

# Inference

In [None]:
from functools import partial


def match_images(sample_id, batch_id, image_1_id, image_2_id, verbose=0, ret_ims=False):
    """Run every configured matcher on one test pair and pool the correspondences.

    Args:
        sample_id: unused here; accepted so callers can pass a full CSV row.
        batch_id, image_1_id, image_2_id: select
            ``{src}/test_images/{batch_id}/{image_id}.png``.
        verbose: if >= 1, print per-matcher runtimes and keypoint counts.
        ret_ims: also return the two loaded BGR images.

    Returns:
        ``(mkpts0, mkpts1)``, the keypoints of all matchers concatenated;
        ``(mkpts0, mkpts1, img_np1, img_np2)`` when ``ret_ims``.

    Raises:
        FileNotFoundError: if either image cannot be read.
    """
    img_paths = [f'{src}/test_images/{batch_id}/{image_id}.png'
                 for image_id in (image_1_id, image_2_id)]
    img_np1, img_np2 = [cv2.imread(path) for path in img_paths]
    for path, img in zip(img_paths, (img_np1, img_np2)):
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise FileNotFoundError(f'Could not read image: {path}')

    matchers_cfg = [
        {
            'name': 'loftr',
            'fn': partial(loftr_matcher, tta=[
                'orig', 'flip_lr',
            ]),
        },
        {
            'name': 'superglue',
            'fn': partial(superglue_matcher, tta_groups=[
                ('orig', 'orig'),
                ('orig', 'rot_r10'),
                ('rot_r10', 'orig'),
                ('flip_lr', 'flip_lr'),
            ], forward_type='cross', input_longside=1600)
        },
        {
            'name': 'dkm',
            'fn': partial(dkm_matcher, tta=[
                'orig', 'flip_lr', 'rot_r10'
            ]),
        },
        # DISK is disabled: its cells are only written to /tmp, never executed.
        # {'name': 'disk', 'fn': disk_matcher, 'tta': ['orig', ]},
    ]
    # Pad all names to the longest one so the runtime lines line up
    # (previously the width grew while iterating).
    max_name_len = max(len(m_cfg['name']) for m_cfg in matchers_cfg)
    mkpts0, mkpts1, runtime_str, kp_count_str = [], [], [], []
    for m_cfg in matchers_cfg:
        m_st = time.time()
        m_mkpts0, m_mkpts1 = m_cfg['fn'](img_np1, img_np2)
        m_nd = time.time()

        mkpts0.append(m_mkpts0)
        mkpts1.append(m_mkpts1)
        runtime_str.append(f'{m_cfg["name"].ljust(max_name_len)}: {m_nd - m_st:06f}s')
        kp_count_str.append(f'{m_cfg["name"]}={len(m_mkpts0)}')

    mkpts0 = np.concatenate(mkpts0)
    mkpts1 = np.concatenate(mkpts1)

    if verbose >= 1:
        print("  - Matching:")
        for s in runtime_str:
            print("    -", s)
        print("  - Keypoints:", len(mkpts0), '|', ', '.join(kp_count_str))
    if ret_ims:
        return mkpts0, mkpts1, img_np1, img_np2
    return mkpts0, mkpts1


class SolutionHolder:
    """Accumulates one fundamental matrix per sample and writes the submission CSV."""

    def __init__(self):
        self.F_dict = dict()  # sample_id -> 3x3 fundamental matrix

    @staticmethod
    def solve_keypoints(mkpts0, mkpts1, verbose=0, ret_inliers=False):
        """Estimate the fundamental matrix from putative matches with MAGSAC.

        Calls cv2.findFundamentalMat with USAC_MAGSAC, a 0.2 px threshold,
        0.9999 confidence and up to 250000 iterations. Earlier runs used
        (0.25, 0.9999, 100000) and (0.1845, 0.999999, 150000). A zero matrix
        is returned when there are 7 or fewer matches, or when OpenCV yields
        no 3x3 model.

        Args:
            mkpts0, mkpts1: (N, 2) arrays of corresponding points.
            verbose: if >= 1, print the estimation time.
            ret_inliers: also return an (N, 1) boolean inlier mask. The mask
                is all False on the zero-matrix fallback; before, it was
                unbound there and raised UnboundLocalError.

        Returns:
            ``F``, or ``(F, inliers)`` when ``ret_inliers``.
        """
        findmat_st = time.time()
        n_pts = len(mkpts0)
        F, inliers = None, None
        if n_pts > 7:
            F, inliers = cv2.findFundamentalMat(mkpts0, mkpts1, cv2.USAC_MAGSAC, 0.2, 0.9999, 250000)
        if F is None or inliers is None or F.shape != (3, 3):
            # Too few matches, or no valid model: submit an all-zero F instead
            # of crashing the whole run.
            F = np.zeros((3, 3))
            inliers = np.zeros((n_pts, 1), dtype=bool)
        else:
            inliers = inliers > 0
        findmat_end = time.time()
        if verbose >= 1:
            print('  - Ransac time:', findmat_end - findmat_st, "s")
        if ret_inliers:
            return F, inliers
        return F

    def add_solution(self, sample_id, mkpts0, mkpts1, verbose=0):
        """Solve F for ``sample_id`` and store it (main_solution runs this in a worker thread)."""
        self.F_dict[sample_id] = SolutionHolder.solve_keypoints(mkpts0, mkpts1, verbose)

    def dump(self, output_file):
        """Write all stored matrices as 'sample_id,fundamental_matrix' CSV rows."""
        with open(output_file, 'w') as f:
            f.write('sample_id,fundamental_matrix\n')
            for sample_id, F in self.F_dict.items():
                f.write(f'{sample_id},{FlattenMatrix(F)}\n')

# Visualize

In [None]:
import warnings
# Silence every UserWarning from here on.
# NOTE(review): presumably to keep the matcher/visualization logs readable.
warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
VISUALIZE = True
import time


def _keypoints_to_lafs(kpts):
    """Wrap (N, 2) keypoints as LAFs (scale 1, orientation 1) for kornia_moons drawing."""
    n_pts = kpts.shape[0]
    return KF.laf_from_center_scale_ori(
        torch.from_numpy(kpts).view(1, -1, 2),
        torch.ones(n_pts).view(1, -1, 1, 1),
        torch.ones(n_pts).view(1, -1, 1),
    )


# Only draw when the public test set (exactly 3 pairs) is present.
if VISUALIZE and len(test_samples) == 3:
    for i, row in enumerate(test_samples):
        sample_id, batch_id, image_1_id, image_2_id = row
        st = time.time()

        mkpts0, mkpts1, img_np1, img_np2 = match_images(
            sample_id, batch_id, image_1_id, image_2_id, verbose=1, ret_ims=True)
        F, inliers = SolutionHolder.solve_keypoints(mkpts0, mkpts1, verbose=1, ret_inliers=True)

        gc_st = time.time()
        gc.collect()
        nd = time.time()
        print("  - gc:", nd - gc_st, "s")
        print("Running time: ", nd - st, " s")

        match_pairs = torch.arange(mkpts0.shape[0]).view(-1, 1).repeat(1, 2)
        style = {
            'inlier_color': (0.2, 1, 0.2),
            'tentative_color': None,
            'feature_color': (0.2, 0.5, 1),
            'vertical': False,
        }
        KMF.draw_LAF_matches(
            _keypoints_to_lafs(mkpts0),
            _keypoints_to_lafs(mkpts1),
            match_pairs,
            cv2.cvtColor(img_np1, cv2.COLOR_BGR2RGB),
            cv2.cvtColor(img_np2, cv2.COLOR_BGR2RGB),
            inliers,
            draw_dict=style,
        )

gc.collect()

In [None]:
import time
import threading


def main_solution(test_samples, output_file):
    """Match every test pair, estimate F in a background thread, and write the CSV.

    Matching of sample i+1 overlaps with RANSAC for sample i. At most one
    RANSAC thread is alive at a time, because the previous one is joined
    before the next starts.

    Args:
        test_samples: rows of [sample_id, batch_id, image_1_id, image_2_id].
        output_file: path of the submission CSV to write.
    """
    solution_holder = SolutionHolder()
    mkpts0, mkpts1 = None, None
    mat_calc_thread = None
    # Bound up front so the final log works for an empty test set
    # (previously an UnboundLocalError after a zero-iteration loop).
    verbose = 0

    for i, row in enumerate(test_samples):
        # Print stats only for the first 3 samples
        verbose = 1 if i < 3 else 0

        sample_id, batch_id, image_1_id, image_2_id = row
        cyc_st = time.time()

        # Drop the previous pair's arrays before matching the next pair; the
        # running RANSAC thread keeps its own references through its args.
        del mkpts0
        del mkpts1

        mkpts0, mkpts1 = match_images(sample_id, batch_id, image_1_id, image_2_id, verbose=verbose)

        # If the RANSAC thread is not finished (it should be), wait for it.
        if mat_calc_thread is not None:
            mat_calc_thread.join()

        mat_calc_thread = threading.Thread(
            target=solution_holder.add_solution,
            args=(sample_id, mkpts0, mkpts1, verbose))
        mat_calc_thread.start()

        cyc_end = time.time()
        gc_st = time.time()
        gc.collect()
        gc_end = time.time()
        if verbose > 0:
            print(f'Iter total: {gc_end - cyc_st:.06f}s  (runtime: {cyc_end - cyc_st:.06f}s, gc: {gc_end - gc_st:.06f}s)')

    # Wait for the last RANSAC job, then write the submission.
    fin_st = time.time()
    if mat_calc_thread is not None:
        mat_calc_thread.join()
    fin_end = time.time()

    if verbose > 0:
        print('Final calc:', fin_end - fin_st, 's')
    solution_holder.dump(output_file)


main_solution(test_samples, 'submission.csv')