In [1]:
import os
import numpy as np
import cv2
from glob import glob
import torch
import matplotlib.pyplot as plt
import gc
import h5py
import torch
from collections import defaultdict
from tqdm import tqdm
from copy import deepcopy
import warnings
from functools import partial
import fastprogress
from tqdm import tqdm as progress_bar
import time
import random 
from collections import defaultdict

# Select the compute device once; the matchers below are moved onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    print('You may want to enable the GPU switch?')

# Records which optional packages the install cells below have already set up in
# this kernel session, so re-running those cells can skip the install step.
INSTALLED_LOG = {}

In [2]:
# Install Kornia
force_kornialoftr_reinstall = False

if 'KorniaLoFTR' not in INSTALLED_LOG or force_kornialoftr_reinstall:
    dry_run = False
    !pip install ../input/kornia-loftr/kornia-0.6.4-py2.py3-none-any.whl
    !pip install ../input/kornia-loftr/kornia_moons-0.1.9-py3-none-any.whl
    INSTALLED_LOG['KorniaLoFTR'] = True
else:
    print('Already installed KorniaLoFTR. Set "force_kornialoftr_reinstall=True" to override this behavior.')
    

# Import and use Kornia
import kornia
import kornia as K
import kornia.feature as KF
import kornia_moons.feature as KMF


class LoFTRMatcher:
    """Dense LoFTR matcher (kornia implementation) with optional test-time augmentation.

    Calling an instance with two BGR images returns matched keypoint coordinates
    expressed in the pixel frame of the original (un-resized) inputs.
    """

    # Rotation TTA name -> rotation angle in degrees applied before matching.
    _ROT_ANGLES = {'rot_r10': 10, 'rot_l10': -10}

    def __init__(self, device=None, input_longside=1200, conf_th=None,
                 ckpt_path="../input/kornia-loftr/loftr_outdoor.ckpt"):
        """
        Args:
            device: torch device the model and input tensors are moved to.
            input_longside: length in px the longer image side is resized to before
                matching; None keeps the original size.
            conf_th: minimum LoFTR confidence for a match to be kept; None keeps
                every match with non-negative confidence.
            ckpt_path: checkpoint file whose 'state_dict' entry is loaded into LoFTR.
        """
        self._loftr_matcher = KF.LoFTR(pretrained=None)
        # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only kernel;
        # the model is moved to `device` right afterwards.
        state = torch.load(ckpt_path, map_location='cpu')
        self._loftr_matcher.load_state_dict(state['state_dict'])
        self._loftr_matcher = self._loftr_matcher.to(device).eval()
        self.device = device
        self.input_longside = input_longside
        self.conf_thresh = conf_th

    def _to_gray_tensor(self, img):
        """Convert a BGR image to a batched grayscale float tensor (scaled by 1/255) on self.device."""
        img_ts = K.image_to_tensor(img, False).float() / 255.
        img_ts = K.color.bgr_to_rgb(img_ts)
        img_ts = K.color.rgb_to_grayscale(img_ts)
        return img_ts.to(self.device)

    def prep_img(self, img, long_side=1200):
        """Optionally resize a BGR image and convert it to LoFTR's input format.

        Args:
            img: BGR image (e.g. as returned by cv2.imread).
            long_side: target length in px of the longer side; None keeps the size.

        Returns:
            (img, img_ts, scale): the (possibly resized) image, its grayscale tensor,
            and the resize factor applied (1.0 when not resized).
        """
        if long_side is not None:
            scale = long_side / max(img.shape[0], img.shape[1])
            w = int(img.shape[1] * scale)
            h = int(img.shape[0] * scale)
            img = cv2.resize(img, (w, h))
        else:
            scale = 1.0
        return img, self._to_gray_tensor(img), scale

    def tta_rotation_preprocess(self, img_np, angle):
        """Rotate `img_np` by `angle` degrees about its centre for test-time augmentation.

        The output keeps the input size, so corners are cropped and uncovered areas
        use cv2.warpAffine's default constant border.

        Returns:
            (rot_M, rot_img_ts, rot_M_inv): the 2x3 forward rotation matrix, the rotated
            image as a grayscale network tensor, and the 2x3 matrix mapping rotated-image
            coordinates back to the original image.
        """
        center = (img_np.shape[1] / 2, img_np.shape[0] / 2)
        rot_M = cv2.getRotationMatrix2D(center, angle, 1)
        rot_M_inv = cv2.getRotationMatrix2D(center, -angle, 1)
        rot_img = cv2.warpAffine(img_np, rot_M, (img_np.shape[1], img_np.shape[0]))
        return rot_M, self._to_gray_tensor(rot_img), rot_M_inv

    def tta_rotation_postprocess(self, kpts, img_np, rot_M_inv):
        """Map keypoints found in a rotated image back to the un-rotated image.

        Args:
            kpts: (N, 2) array of (x, y) keypoints in rotated-image coordinates.
            img_np: the un-rotated image (only its height/width are used).
            rot_M_inv: 2x3 inverse rotation matrix from tta_rotation_preprocess.

        Returns:
            (rot_kpts, mask): the mapped (N, 2) keypoints and a boolean mask that is
            True where the mapped point lies inside the image.
        """
        ones = np.ones(shape=(kpts.shape[0], ), dtype=np.float32)[:, None]
        hom = np.concatenate([kpts, ones], 1)
        rot_kpts = rot_M_inv.dot(hom.T).T[:, :2]
        mask = (rot_kpts[:, 0] >= 0) & (rot_kpts[:, 0] < img_np.shape[1]) & (rot_kpts[:, 1] >= 0) & (rot_kpts[:, 1] < img_np.shape[0])
        return rot_kpts, mask

    @staticmethod
    def _unflip(kpts, rows, axis, size):
        """In place, undo a flip along `axis` (0 = x, 1 = y) for the selected `rows` of `kpts`.

        Pixel index i in a flipped axis of length `size` came from index size - 1 - i,
        the same convention SuperGlueCustomMatchingV2.untta_fliplr uses.
        """
        kpts[rows, axis] = (size - 1.) - kpts[rows, axis]

    def __call__(self, img_np1, img_np2, tta=('orig',)):
        """Match two BGR images with LoFTR, pooling matches over the requested TTA variants.

        Args:
            img_np1, img_np2: BGR images (e.g. from cv2.imread).
            tta: sequence of augmentation names, each one of 'orig', 'flip_lr',
                'flip_ud', 'rot_r10', 'rot_l10'. All variants run in one batch.

        Returns:
            (mkpts0, mkpts1): (N, 2) arrays of corresponding (x, y) coordinates in the
            original-resolution frames of img_np1 and img_np2.

        Raises:
            ValueError: if `tta` contains an unknown augmentation name.
        """
        with torch.no_grad():
            img_np1, img_ts0, scale0 = self.prep_img(img_np1, long_side=self.input_longside)
            img_np2, img_ts1, scale1 = self.prep_img(img_np2, long_side=self.input_longside)
            h0, w0 = img_np1.shape[:2]
            h1, w1 = img_np2.shape[:2]

            # Build one batch entry per TTA variant.
            images0, images1 = [], []
            rot_inv = {}  # rotation TTA name -> (inverse matrix for image 0, for image 1)
            for tta_elem in tta:
                if tta_elem == 'orig':
                    img_ts0_aug, img_ts1_aug = img_ts0, img_ts1
                elif tta_elem == 'flip_lr':
                    img_ts0_aug = torch.flip(img_ts0, [3])
                    img_ts1_aug = torch.flip(img_ts1, [3])
                elif tta_elem == 'flip_ud':
                    img_ts0_aug = torch.flip(img_ts0, [2])
                    img_ts1_aug = torch.flip(img_ts1, [2])
                elif tta_elem in self._ROT_ANGLES:
                    angle = self._ROT_ANGLES[tta_elem]
                    _, img_ts0_aug, inv0 = self.tta_rotation_preprocess(img_np1, angle)
                    _, img_ts1_aug, inv1 = self.tta_rotation_preprocess(img_np2, angle)
                    rot_inv[tta_elem] = (inv0, inv1)
                else:
                    raise ValueError('Unknown TTA method.')
                images0.append(img_ts0_aug)
                images1.append(img_ts1_aug)

            # Single batched LoFTR pass over all variants.
            input_dict = {"image0": torch.cat(images0), "image1": torch.cat(images1)}
            correspondences = self._loftr_matcher(input_dict)
            mkpts0 = correspondences['keypoints0'].cpu().numpy()
            mkpts1 = correspondences['keypoints1'].cpu().numpy()
            batch_id = correspondences['batch_indexes'].cpu().numpy()
            confidence = correspondences['confidence'].cpu().numpy()

            # Map each variant's matches back to the un-augmented (resized) frames.
            # 'orig' needs no correction; unknown names were rejected above.
            for idx, tta_elem in enumerate(tta):
                rows = batch_id == idx
                if tta_elem == 'flip_lr':
                    self._unflip(mkpts0, rows, 0, w0)
                    self._unflip(mkpts1, rows, 0, w1)
                elif tta_elem == 'flip_ud':
                    self._unflip(mkpts0, rows, 1, h0)
                    self._unflip(mkpts1, rows, 1, h1)
                elif tta_elem in rot_inv:
                    inv0, inv1 = rot_inv[tta_elem]
                    mkpts0[rows], mask0 = self.tta_rotation_postprocess(mkpts0[rows], img_np1, inv0)
                    mkpts1[rows], mask1 = self.tta_rotation_postprocess(mkpts1[rows], img_np2, inv1)
                    # Lower the confidence of matches landing outside either image by 10
                    # so the threshold below discards them.
                    confidence[rows] += (~(mask0 & mask1)).astype(np.float32) * -10.

            min_conf = 0. if self.conf_thresh is None else self.conf_thresh
            keep = confidence >= min_conf
            # Undo the resize so coordinates refer to the input images.
            return mkpts0[keep] / scale0, mkpts1[keep] / scale1
        

# LoFTR matcher: inputs are resized so their long side is 1200 px (noted as the
# validation size in the LoFTR paper); matches with confidence below 0.9 are dropped.
loftr_matcher = LoFTRMatcher(
    conf_th=0.9,
    input_longside=1200,
    device=device,
)

Processing /kaggle/input/kornia-loftr/kornia-0.6.4-py2.py3-none-any.whl
Installing collected packages: kornia
  Attempting uninstall: kornia
    Found existing installation: kornia 0.7.2
    Uninstalling kornia-0.7.2:
      Successfully uninstalled kornia-0.7.2
Successfully installed kornia-0.6.4
Processing /kaggle/input/kornia-loftr/kornia_moons-0.1.9-py3-none-any.whl
Installing collected packages: kornia-moons
Successfully installed kornia-moons-0.1.9


In [3]:
# Install superglue
force_superglue_reinstall = False

if 'superglue' not in INSTALLED_LOG or force_superglue_reinstall:
    !mkdir /tmp/superpoint
    !cp -r ../input/super-glue-pretrained-network/models /tmp/superpoint/superpoint
    !ls /tmp/superpoint/superpoint
    !touch /tmp/superpoint/superpoint/__init__.py
    INSTALLED_LOG['superglue'] = True
else:
    print('Already installed SuperGlue. Set "force_superglue_reinstall=True" to override this behavior.')

# Import superglue
import sys
sys.path.append("/tmp/superpoint")
from superpoint.superpoint import SuperPoint
from superpoint.superglue import SuperGlue


class SuperGlueCustomMatchingV2(torch.nn.Module):
    """Image matching front end (SuperPoint + SuperGlue) with test-time augmentation.

    Each `untta_*` method maps keypoints detected on an augmented image back to the
    frame of the un-augmented image. They share one signature,
    (keypoints, descriptors, scores, w, h, inplace=True, mask_illegal=True) ->
    (keypoints, descriptors, scores), where `w`/`h` are the network-input width/height
    and `mask_illegal` drops keypoints that end up outside the image.
    """

    # Rotation in degrees undone by the rotation un-TTAs. Despite the "10" in the TTA
    # names the value is 15; SuperGlueMatcherV2 rotates its inputs by its own 15-degree
    # constant, so the two must stay equal.
    ROT_ANGLE_DEG = 15

    def __init__(self, config=None, device=None):
        """
        Args:
            config: optional dict with 'superpoint' and 'superglue' sub-config dicts.
            device: device the rotation un-TTA matrices are created on.
        """
        super().__init__()
        config = {} if config is None else config
        self.superpoint = SuperPoint(config.get('superpoint', {}))
        self.superglue = SuperGlue(config.get('superglue', {}))

        self.tta_map = {
            'orig': self.untta_none,
            'eqhist': self.untta_none,
            'clahe': self.untta_none,
            'flip_lr': self.untta_fliplr,
            'flip_ud': self.untta_flipud,
            'rot_r10': self.untta_rotr10,
            'rot_l10': self.untta_rotl10,
            'fliplr_rotr10': self.untta_fliplr_rotr10,
            'fliplr_rotl10': self.untta_fliplr_rotl10
        }
        self.device = device

    def _extract_missing_features(self, data):
        """Run SuperPoint on each image lacking 'keypoints{k}'; return the results keyed with suffix k."""
        pred = {}
        for k in (0, 1):
            if f'keypoints{k}' not in data:
                feats = self.superpoint({'image': data[f'image{k}']})
                pred.update({name + str(k): v for name, v in feats.items()})
        return pred

    @staticmethod
    def _select(items, member):
        """Return the entries of `items` whose matching flag in `member` is True."""
        return [item for item, keep in zip(items, member) if keep]

    def forward_flat(self, data, ttas=('orig',), tta_groups=(('orig',),)):
        """Run SuperPoint, undo the TTA on every keypoint set, then SuperGlue once per group.

        Within a group, the un-augmented keypoints of all its TTA variants are pooled
        into one keypoint set per image before matching.

        Args:
            data: dict with at least 'image0' and 'image1', batches whose entry i was
                augmented with ttas[i]. If 'keypoints{k}' is present SuperPoint is
                skipped for that image; 'descriptors{k}' and 'scores{k}' must then be
                present as well.
            ttas: augmentation name per batch entry.
            tta_groups: iterable of collections of TTA names to pool together.

        Returns:
            list with one dict per group: the pooled inputs plus SuperGlue's outputs.
        """
        # 1. SuperPoint features for every image that does not have them yet.
        data = {**data, **self._extract_missing_features(data)}

        # 2. Undo each augmentation, dropping keypoints that land outside the image.
        #    Fresh lists allow per-image entries to be replaced.
        for k in (0, 1):
            for name in ('keypoints', 'descriptors', 'scores'):
                data[f'{name}{k}'] = list(data[f'{name}{k}'])
        for i in range(len(data['keypoints0'])):
            untta = self.tta_map[ttas[i]]
            for k in (0, 1):
                image = data[f'image{k}']
                data[f'keypoints{k}'][i], data[f'descriptors{k}'][i], data[f'scores{k}'][i] = untta(
                    data[f'keypoints{k}'][i], data[f'descriptors{k}'][i], data[f'scores{k}'][i],
                    w=image.shape[3], h=image.shape[2], inplace=True, mask_illegal=True)

        # 3. Pool each group's features into a single batch entry and run SuperGlue on it.
        group_preds = []
        for tta_group in tta_groups:
            member = [tta_name in tta_group for tta_name in ttas]
            # Builtin `bool`: the `np.bool` alias was removed in NumPy 1.24.
            group_mask = torch.from_numpy(np.array(member, dtype=bool))
            group_data = {
                **{f'keypoints{k}': torch.cat(self._select(data[f'keypoints{k}'], member))[None, ...] for k in (0, 1)},
                **{f'descriptors{k}': torch.cat(self._select(data[f'descriptors{k}'], member), 1)[None, ...] for k in (0, 1)},
                **{f'scores{k}': torch.cat(self._select(data[f'scores{k}'], member))[None, ...] for k in (0, 1)},
                # Selected images (n, C, H, W) are packed into one (1, n*C, H, W) tensor.
                **{f'image{k}': torch.flatten(data[f'image{k}'][group_mask, ...], 0, 1)[None, ...] for k in (0, 1)},
            }
            group_preds.append({**group_data, **self.superglue(group_data)})
        return group_preds

    def forward_cross(self, data, ttas=('orig',), tta_groups=(('orig', 'orig'),)):
        """Run SuperPoint, then SuperGlue on explicit (TTA of image0, TTA of image1) pairs.

        Matching uses the augmented keypoints. Afterwards the keypoints are mapped back
        to the un-augmented frame without dropping any of them (mask_illegal=False), so
        the match indices returned by SuperGlue stay valid.

        Args:
            data: dict with 'image0'/'image1' batches; entry i was augmented with ttas[i].
            ttas: augmentation name per batch entry.
            tta_groups: iterable of (tta_name_for_image0, tta_name_for_image1) pairs.

        Returns:
            list with one dict per pair: SuperGlue's outputs plus 'keypoints{0,1}' and
            'scores{0,1}' (each a one-element list) in un-augmented coordinates.
        """
        data = {**data, **self._extract_missing_features(data)}

        # SuperGlue once per (image0 variant, image1 variant) pair.
        group_pred_list = []
        tta2id = {name: i for i, name in enumerate(ttas)}
        for tta_group in tta_groups:
            group_idx = tta2id[tta_group[0]], tta2id[tta_group[1]]
            group_data = {
                **{f'image{i}': data[f'image{i}'][group_idx[i]:group_idx[i]+1] for i in [0, 1]},
                **{f'keypoints{i}': data[f'keypoints{i}'][group_idx[i]:group_idx[i]+1] for i in [0, 1]},
                **{f'descriptors{i}': data[f'descriptors{i}'][group_idx[i]:group_idx[i]+1] for i in [0, 1]},
                **{f'scores{i}': data[f'scores{i}'][group_idx[i]:group_idx[i]+1] for i in [0, 1]},
            }
            # Per-image lists/tuples become batched tensors of size 1.
            for k in group_data:
                if isinstance(group_data[k], (list, tuple)):
                    group_data[k] = torch.stack(group_data[k])
            group_pred_list.append(self.superglue(group_data))

        # Undo the augmentation on all keypoints (no masking, indices must be preserved).
        data['scores0'] = list(data['scores0'])
        data['scores1'] = list(data['scores1'])
        for i in range(len(data['keypoints0'])):
            untta = self.tta_map[ttas[i]]
            for k in (0, 1):
                image = data[f'image{k}']
                data[f'keypoints{k}'][i], data[f'descriptors{k}'][i], data[f'scores{k}'][i] = untta(
                    data[f'keypoints{k}'][i], data[f'descriptors{k}'][i], data[f'scores{k}'][i],
                    w=image.shape[3], h=image.shape[2], inplace=True, mask_illegal=False)

        # Attach the un-augmented keypoints/scores that each pair's matches refer to.
        for group_pred, tta_group in zip(group_pred_list, tta_groups):
            group_idx = tta2id[tta_group[0]], tta2id[tta_group[1]]
            group_pred.update({
                **{f'keypoints{i}': data[f'keypoints{i}'][group_idx[i]:group_idx[i]+1] for i in [0, 1]},
                **{f'scores{i}': data[f'scores{i}'][group_idx[i]:group_idx[i]+1] for i in [0, 1]},
            })
        return group_pred_list

    def untta_none(self, keypoints, descriptors, scores, w, h, inplace=True, mask_illegal=True):
        """Identity un-TTA (also used for photometric TTAs that do not move pixels)."""
        if not inplace:
            keypoints = keypoints.clone()
        return keypoints, descriptors, scores

    def untta_fliplr(self, keypoints, descriptors, scores, w, h, inplace=True, mask_illegal=True):
        """Undo a horizontal flip: x -> w - x - 1 (pixel index i came from w - 1 - i)."""
        if not inplace:
            keypoints = keypoints.clone()
        keypoints[:, 0] = w - keypoints[:, 0] - 1.
        return keypoints, descriptors, scores

    def untta_flipud(self, keypoints, descriptors, scores, w, h, inplace=True, mask_illegal=True):
        """Undo a vertical flip: y -> h - y - 1."""
        if not inplace:
            keypoints = keypoints.clone()
        keypoints[:, 1] = h - keypoints[:, 1] - 1.
        return keypoints, descriptors, scores

    def _untta_rotation(self, keypoints, descriptors, scores, w, h, angle, flip_lr=False, mask_illegal=True):
        """Undo a rotation by `angle` degrees about the image centre.

        If `flip_lr`, the input image was flipped horizontally BEFORE being rotated, so
        the flip is undone AFTER the un-rotation. A new keypoint tensor is always
        produced, so the callers' `inplace` flag has no effect here.
        """
        rot_M_inv = torch.from_numpy(
            cv2.getRotationMatrix2D((w / 2, h / 2), -angle, 1)).to(torch.float32).to(self.device)
        ones = torch.ones_like(keypoints[:, 0])
        hom = torch.cat([keypoints, ones[:, None]], 1)
        rot_kpts = torch.matmul(rot_M_inv, hom.T).T[:, :2]
        if flip_lr:
            rot_kpts[:, 0] = w - rot_kpts[:, 0] - 1.
        if mask_illegal:
            mask = (rot_kpts[:, 0] >= 0) & (rot_kpts[:, 0] < w) & (rot_kpts[:, 1] >= 0) & (rot_kpts[:, 1] < h)
            return rot_kpts[mask], descriptors[:, mask], scores[mask]
        return rot_kpts, descriptors, scores

    def untta_rotr10(self, keypoints, descriptors, scores, w, h, inplace=True, mask_illegal=True):
        """Undo the 'rot_r10' augmentation (rotation by +ROT_ANGLE_DEG)."""
        return self._untta_rotation(keypoints, descriptors, scores, w, h,
                                    self.ROT_ANGLE_DEG, mask_illegal=mask_illegal)

    def untta_rotl10(self, keypoints, descriptors, scores, w, h, inplace=True, mask_illegal=True):
        """Undo the 'rot_l10' augmentation (rotation by -ROT_ANGLE_DEG)."""
        return self._untta_rotation(keypoints, descriptors, scores, w, h,
                                    -self.ROT_ANGLE_DEG, mask_illegal=mask_illegal)

    def untta_fliplr_rotr10(self, keypoints, descriptors, scores, w, h, inplace=True, mask_illegal=True):
        """Undo 'fliplr_rotr10' (horizontal flip, then rotation by +ROT_ANGLE_DEG)."""
        return self._untta_rotation(keypoints, descriptors, scores, w, h,
                                    self.ROT_ANGLE_DEG, flip_lr=True, mask_illegal=mask_illegal)

    def untta_fliplr_rotl10(self, keypoints, descriptors, scores, w, h, inplace=True, mask_illegal=True):
        """Undo 'fliplr_rotl10' (horizontal flip, then rotation by -ROT_ANGLE_DEG)."""
        return self._untta_rotation(keypoints, descriptors, scores, w, h,
                                    -self.ROT_ANGLE_DEG, flip_lr=True, mask_illegal=mask_illegal)


class SuperGlueMatcherV2:
    """SuperPoint + SuperGlue matcher with test-time augmentation (TTA).

    Calling an instance with two BGR images returns matched keypoint coordinates in
    the pixel frame of the original (un-resized) inputs.
    """

    # Rotation in degrees for the rotation TTAs. Despite the "10" in their names they
    # rotate by 15 degrees; SuperGlueCustomMatchingV2 undoes the rotation with its own
    # 15-degree constant, so the two values must stay equal.
    ROT_ANGLE_DEG = 15

    # Rotation TTA name -> (flip horizontally before rotating?, sign of the angle)
    _ROT_TTAS = {
        'rot_r10': (False, 1),
        'rot_l10': (False, -1),
        'fliplr_rotr10': (True, 1),
        'fliplr_rotl10': (True, -1),
    }

    def __init__(self, device=None, conf_th=None):
        """
        Args:
            device: torch device for the network and the input tensors.
            conf_th: minimum SuperGlue matching score for a match to be kept;
                None keeps every match SuperGlue reports.
        """
        config = {
            "superpoint": {
                "nms_radius": 2,  # non-maximum-suppression radius passed to SuperPoint
                "keypoint_threshold": 0.2,  # minimum detector score for a keypoint to be kept
                "max_keypoints": 3000,  # cap on keypoints per image
            },
            "superglue": {
                "weights": "outdoor",  # pretrained weight set to load
                "sinkhorn_iterations": 100,  # number of Sinkhorn normalization iterations
                "match_threshold": 0.2,  # minimum score for SuperGlue to report a match
            }
        }
        self.device = device
        self._superglue_matcher = SuperGlueCustomMatchingV2(
            config=config, device=self.device,
            ).eval().to(device)

        self.conf_thresh = conf_th

    def prep_np_img(self, img, long_side=None):
        """Optionally resize a BGR image so its long side is `long_side` px, then convert to gray.

        Returns:
            (gray_img, scale): the grayscale image and the resize factor (1.0 if not resized).
        """
        if long_side is not None:
            scale = long_side / max(img.shape[0], img.shape[1])
            w = int(img.shape[1] * scale)
            h = int(img.shape[0] * scale)
            img = cv2.resize(img, (w, h))
        else:
            scale = 1.0
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), scale

    def frame2tensor(self, frame):
        """Convert an HxW image to a (1, 1, H, W) float tensor scaled by 1/255 on self.device."""
        return (torch.from_numpy(frame).float()/255.)[None, None].to(self.device)

    def tta_rotation_preprocess(self, img_np, angle):
        """Rotate `img_np` by `angle` degrees about its centre, keeping its size.

        Returns:
            (rot_M, rot_img_ts, rot_M_inv): the 2x3 forward matrix, the rotated image as
            a network tensor, and the 2x3 matrix mapping rotated coordinates back.
        """
        rot_M = cv2.getRotationMatrix2D((img_np.shape[1] / 2, img_np.shape[0] / 2), angle, 1)
        rot_M_inv = cv2.getRotationMatrix2D((img_np.shape[1] / 2, img_np.shape[0] / 2), -angle, 1)
        rot_img = self.frame2tensor(cv2.warpAffine(img_np, rot_M, (img_np.shape[1], img_np.shape[0])))
        return rot_M, rot_img, rot_M_inv

    def tta_rotation_postprocess(self, kpts, img_np, rot_M_inv):
        """Map (N, 2) keypoints from a rotated image back; return them plus an in-bounds mask."""
        ones = np.ones(shape=(kpts.shape[0], ), dtype=np.float32)[:, None]
        hom = np.concatenate([kpts, ones], 1)
        rot_kpts = rot_M_inv.dot(hom.T).T[:, :2]
        mask = (rot_kpts[:, 0] >= 0) & (rot_kpts[:, 0] < img_np.shape[1]) & (rot_kpts[:, 1] >= 0) & (rot_kpts[:, 1] < img_np.shape[0])
        return rot_kpts, mask

    @staticmethod
    def _unique_ttas(tta_groups):
        """Flatten `tta_groups` into a de-duplicated list in first-seen order.

        Unlike list(set(...)), the order (and hence batch layout) is reproducible
        across interpreter runs.
        """
        return list(dict.fromkeys(name for group in tta_groups for name in group))

    def _collect_matches(self, group_preds):
        """Gather matched keypoint pairs from every group's SuperGlue output.

        Returns:
            two (N, 2) float arrays of corresponding (x, y) coordinates in the resized
            network-input frames; N may be 0.
        """
        mkpts0 = [np.zeros((0, 2), dtype=np.float32)]
        mkpts1 = [np.zeros((0, 2), dtype=np.float32)]
        for group_pred in group_preds:
            # `[0]` drops the batch axis. No squeeze(): it would collapse a single
            # keypoint's (1, 2) array to (2,) and corrupt the indexing below.
            kpts0 = group_pred['keypoints0'][0].detach().cpu().numpy()
            kpts1 = group_pred['keypoints1'][0].detach().cpu().numpy()
            matches = group_pred['matches0'][0].detach().cpu().numpy()
            conf = group_pred['matching_scores0'][0].detach().cpu().numpy()

            valid = matches > -1
            if self.conf_thresh is not None:
                valid &= conf >= self.conf_thresh
            mkpts0.append(kpts0[valid])
            mkpts1.append(kpts1[matches[valid]])
        return np.concatenate(mkpts0), np.concatenate(mkpts1)

    def __call__(self, img_np0, img_np1, tta_groups=(('orig',),), forward_type='cross', input_longside=None):
        """Match two BGR images with SuperPoint + SuperGlue under test-time augmentation.

        Args:
            img_np0, img_np1: BGR images (e.g. from cv2.imread).
            tta_groups: for forward_type='cross', (tta_for_image0, tta_for_image1) pairs;
                for 'flat', collections of TTA names whose keypoints are pooled.
            forward_type: 'cross' or 'flat' (see SuperGlueCustomMatchingV2).
            input_longside: resize so the long side has this many px; None keeps the size.

        Returns:
            (mkpts0, mkpts1): (N, 2) arrays of matched coordinates in the original input
            frames, pooled over all groups; matches outside either image are dropped.

        Raises:
            ValueError: unknown TTA name.
            RuntimeError: unknown forward_type.
        """
        with torch.no_grad():
            img_np0, scale0 = self.prep_np_img(img_np0, input_longside)
            img_np1, scale1 = self.prep_np_img(img_np1, input_longside)

            img_ts0 = self.frame2tensor(img_np0)
            img_ts1 = self.frame2tensor(img_np1)

            tta = self._unique_ttas(tta_groups)

            # One batch entry per distinct TTA variant.
            images0, images1 = [], []
            for tta_elem in tta:
                if tta_elem == 'orig':
                    img_ts0_aug, img_ts1_aug = img_ts0, img_ts1
                elif tta_elem == 'flip_lr':
                    img_ts0_aug = torch.flip(img_ts0, [3])
                    img_ts1_aug = torch.flip(img_ts1, [3])
                elif tta_elem == 'flip_ud':
                    img_ts0_aug = torch.flip(img_ts0, [2])
                    img_ts1_aug = torch.flip(img_ts1, [2])
                elif tta_elem in self._ROT_TTAS:
                    flip, sign = self._ROT_TTAS[tta_elem]
                    angle = sign * self.ROT_ANGLE_DEG
                    # Contiguous copy: the [:, ::-1] view has a negative stride.
                    src0 = np.ascontiguousarray(img_np0[:, ::-1]) if flip else img_np0
                    src1 = np.ascontiguousarray(img_np1[:, ::-1]) if flip else img_np1
                    _, img_ts0_aug, _ = self.tta_rotation_preprocess(src0, angle)
                    _, img_ts1_aug, _ = self.tta_rotation_preprocess(src1, angle)
                elif tta_elem == 'eqhist':
                    img_ts0_aug = self.frame2tensor(cv2.equalizeHist(img_np0))
                    img_ts1_aug = self.frame2tensor(cv2.equalizeHist(img_np1))
                elif tta_elem == 'clahe':
                    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
                    img_ts0_aug = self.frame2tensor(clahe.apply(img_np0))
                    img_ts1_aug = self.frame2tensor(clahe.apply(img_np1))
                else:
                    raise ValueError('Unknown TTA method.')

                images0.append(img_ts0_aug)
                images1.append(img_ts1_aug)

            # Inference (the matcher also undoes the augmentations).
            batch = {
                "image0": torch.cat(images0),
                "image1": torch.cat(images1)
            }
            if forward_type == 'cross':
                pred = self._superglue_matcher.forward_cross(
                    data=batch, ttas=tta, tta_groups=tta_groups)
            elif forward_type == 'flat':
                pred = self._superglue_matcher.forward_flat(
                    data=batch, ttas=tta, tta_groups=tta_groups)
            else:
                raise RuntimeError(f'Unknown forward_type {forward_type}')

            cat_mkpts0, cat_mkpts1 = self._collect_matches(pred)
            # Keep only pairs whose both ends lie inside their (resized) images.
            in_bounds = (
                (cat_mkpts0[:, 0] >= 0) & (cat_mkpts0[:, 0] < img_np0.shape[1])
                & (cat_mkpts0[:, 1] >= 0) & (cat_mkpts0[:, 1] < img_np0.shape[0])
                & (cat_mkpts1[:, 0] >= 0) & (cat_mkpts1[:, 0] < img_np1.shape[1])
                & (cat_mkpts1[:, 1] >= 0) & (cat_mkpts1[:, 1] < img_np1.shape[0])
            )
            # Undo the resize so coordinates refer to the input images.
            return cat_mkpts0[in_bounds] / scale0, cat_mkpts1[in_bounds] / scale1


# SuperGlue matcher; matches with a SuperGlue score below 0.5 are discarded.
# The input resolution is chosen per call: match_images passes input_longside=1600
# (reportedly the validation size used in the SuperGlue paper).
superglue_matcher = SuperGlueMatcherV2(device=device, conf_th=0.5)


__init__.py  matching.py  superglue.py	superpoint.py  utils.py  weights
Loaded SuperPoint model
Loaded SuperGlue model ("outdoor" weights)


In [4]:
warnings.filterwarnings("ignore", category=UserWarning)

def get_unique_idxs(A, dim=1):
    """Return the index of the first occurrence of every distinct slice of `A` along `dim`.

    Slices are compared whole (e.g. with dim=1 two columns are equal only if all their
    entries match). The returned indices follow the sorted order of the unique values
    produced by `torch.unique(..., sorted=True)`, not their position in `A`.

    Args:
        A: tensor to de-duplicate.
        dim: dimension along which slices are compared (any valid dim, incl. negative).

    Returns:
        1-D long tensor of indices into `A` along `dim`; empty if `A` has size 0 there.
    """
    if A.size(dim) == 0:
        return torch.tensor([], dtype=torch.long, device=A.device)

    _, inverse, counts = torch.unique(A, dim=dim, sorted=True, return_inverse=True, return_counts=True)
    # A stable sort of the inverse mapping groups equal slices together while keeping
    # their original order, so the first element of each group is its first occurrence.
    _, order = torch.sort(inverse, stable=True)
    # Position in `order` at which each unique value's group starts.
    group_starts = torch.cat((counts.new_zeros(1), counts.cumsum(0)[:-1]))
    return order[group_starts]



def match_images(img_path1, img_path2, matchers_cfg=None):
    """Run every configured matcher on an image pair and pool their matches.

    Args:
        img_path1, img_path2: Paths of the two images (read with cv2.imread).
        matchers_cfg: Optional list of ``{'name': str, 'fn': callable}`` dicts.
            Each ``fn(img_np1, img_np2)`` must return ``(mkpts0, mkpts1)``.
            When omitted, the defaults are ``loftr_matcher`` with
            ``tta=['orig', 'flip_lr']``, followed by ``superglue_matcher``
            with four TTA groups, ``forward_type='cross'`` and
            ``input_longside=1600``.

    Returns:
        ``(mkpts0, mkpts1, kp_sources, img_np1, img_np2)``. ``mkpts0[i]`` and
        ``mkpts1[i]`` form match ``i``. ``kp_sources[i]`` is the ``'name'`` of
        the matcher that produced it. ``img_np1``/``img_np2`` are the loaded
        images.

    Raises:
        FileNotFoundError: if either image cannot be read. cv2.imread returns
            None instead of raising, which would otherwise fail obscurely
            inside a matcher.
    """
    img_np1 = cv2.imread(img_path1)
    img_np2 = cv2.imread(img_path2)
    for path, img in ((img_path1, img_np1), (img_path2, img_np2)):
        if img is None:
            raise FileNotFoundError(f"Image not found at path: {path}")

    if matchers_cfg is None:
        matchers_cfg = [
            {
                'name': 'loftr',
                'fn': partial(loftr_matcher, tta=['orig', 'flip_lr']),
            },
            {
                'name': 'superglue',
                'fn': partial(superglue_matcher, tta_groups=[
                    ('orig', 'orig'), ('orig', 'rot_r10'), ('rot_r10', 'orig'), ('flip_lr', 'flip_lr')
                ], forward_type='cross', input_longside=1600),
            },
        ]

    mkpts0, mkpts1, kp_sources = [], [], []
    for m_cfg in matchers_cfg:
        m_mkpts0, m_mkpts1 = m_cfg['fn'](img_np1, img_np2)
        mkpts0.append(m_mkpts0)
        mkpts1.append(m_mkpts1)
        # Tag each match with the matcher that produced it.
        kp_sources.append([m_cfg['name']] * len(m_mkpts0))

    mkpts0 = np.concatenate(mkpts0)
    mkpts1 = np.concatenate(mkpts1)
    kp_sources = np.concatenate(kp_sources)

    return mkpts0, mkpts1, kp_sources, img_np1, img_np2

In [5]:
# Main script: enumerate the input images and every unordered pair of them.
dirname = '/kaggle/input/miawww'
# os.listdir returns entries in arbitrary order; sort so that image indices,
# pair order and the keypoint indices derived from them are reproducible.
img_fnames = sorted(
    os.path.join(dirname, x) for x in os.listdir(dirname) if x.endswith('.jpg')
)

# Every (i, j) index pair with i < j, i.e. each unordered pair exactly once.
pairs_within_dir = [(i, j) for i in range(len(img_fnames)) for j in range(i + 1, len(img_fnames))]

feature_dir = '/kaggle/working/'
os.makedirs(feature_dir, exist_ok=True)

In [None]:
# Match every image pair and accumulate, per image, the matched keypoints plus
# per-pair index tables pointing into those accumulated keypoint lists.
kpts = defaultdict(list)            # image name -> list of keypoint arrays, one per pair it appears in
match_indexes = defaultdict(dict)   # name1 -> name2 -> [[idx into kpts[name1], idx into kpts[name2]], ...]
total_kpts = defaultdict(int)       # image name -> number of keypoints appended so far
# Populated by later cells; created here so re-running this cell resets them too.
unique_kpts = {}
unique_match_idxs = {}
out_match = defaultdict(dict)

for idx1, idx2 in tqdm(pairs_within_dir):
    img_path1, img_path2 = img_fnames[idx1], img_fnames[idx2]
    img_name1 = os.path.basename(img_path1)
    img_name2 = os.path.basename(img_path2)
    mkpts0, mkpts1, _, _, _ = match_images(img_path1, img_path2)

    # Keep only the first match for each distinct keypoint in image 1, so one
    # image-1 point is not linked to several image-2 points within this pair.
    unique_idxs = get_unique_idxs(torch.from_numpy(mkpts0), dim=0).numpy()
    unique_mkpts0 = mkpts0[unique_idxs]
    unique_mkpts1 = mkpts1[unique_idxs]
    n_matches = len(unique_idxs)

    kpts[img_name1].append(unique_mkpts0)
    kpts[img_name2].append(unique_mkpts1)

    # Match i of this pair is the i-th keypoint just appended to each image's
    # list, offset by how many keypoints that image had accumulated before.
    current_match = torch.arange(n_matches).reshape(-1, 1).repeat(1, 2)
    current_match[:, 0] += total_kpts[img_name1]
    current_match[:, 1] += total_kpts[img_name2]

    total_kpts[img_name1] += n_matches
    total_kpts[img_name2] += n_matches

    match_indexes[img_name1][img_name2] = current_match.numpy().tolist()


 57%|█████▋    | 2773/4851 [43:34<32:40,  1.06it/s]  

In [None]:
# Merge each image's per-pair keypoint arrays into one array, then collapse
# duplicate keypoints so every distinct location gets a single index.
for key, chunks in kpts.items():
    # Only merge lists. Re-running this cell on already-merged arrays would
    # otherwise flatten them (np.concatenate over a 2-D array joins its rows
    # into one 1-D array) and silently corrupt the keypoints below.
    if isinstance(chunks, list):
        kpts[key] = np.concatenate(chunks)

for k, merged in kpts.items():
    uniq_kps, uniq_reverse_idxs = torch.unique(torch.from_numpy(merged), dim=0, return_inverse=True)
    # uniq_reverse_idxs[i] = row of unique_kpts[k] that accumulated keypoint i maps to.
    unique_match_idxs[k] = uniq_reverse_idxs
    unique_kpts[k] = uniq_kps.numpy()

In [None]:
# Re-express every pair's matches in terms of the de-duplicated keypoints and
# drop duplicate and many-to-one matches.
for k1, group in match_indexes.items():
    for k2, m in group.items():
        # np.array already copies the stored list (no deepcopy needed); the
        # reshape also turns an empty list into a (0, 2) array.
        m2 = np.array(m, dtype=np.int64).reshape(-1, 2)
        if m2.shape[0] == 0:
            continue  # nothing matched for this pair
        if k1 not in unique_match_idxs or k2 not in unique_match_idxs:
            continue  # no de-duplicated keypoints available for one of the images

        # Map accumulated per-image keypoint indices onto rows of unique_kpts.
        m2[:, 0] = unique_match_idxs[k1][m2[:, 0]]
        m2[:, 1] = unique_match_idxs[k2][m2[:, 1]]

        # 1) Drop matches whose (keypoint-1, keypoint-2) coordinates repeat.
        mkpts = np.concatenate([unique_kpts[k1][m2[:, 0]], unique_kpts[k2][m2[:, 1]]], axis=1)
        m2 = torch.from_numpy(m2)
        m2_semiclean = m2[get_unique_idxs(torch.from_numpy(mkpts), dim=0)]
        # 2) Keep a single match per image-1 keypoint.
        m2_semiclean = m2_semiclean[get_unique_idxs(m2_semiclean[:, 0], dim=0)]
        # 3) Keep a single match per image-2 keypoint.
        m2_semiclean2 = m2_semiclean[get_unique_idxs(m2_semiclean[:, 1], dim=0)]
        out_match[k1][k2] = m2_semiclean2.tolist()  # plain lists; turned into an array when written to HDF5


In [None]:
# Write the de-duplicated keypoints: one float32 dataset per image, named after
# the image file. Written under feature_dir, where load_keypoints reads it back.
file_path = os.path.join(feature_dir, "keypoints.h5")
with h5py.File(file_path, 'w') as f:
    try:
        for image_path, keypoints_array in unique_kpts.items():
            # Keys are already basenames; basename() is kept as a safeguard.
            image_filename = os.path.basename(image_path)
            f.create_dataset(image_filename, data=keypoints_array.astype(np.float32))
            f[image_filename].attrs['description'] = f'Keypoints for image {image_filename}'
            print(f"Dataset: {image_filename}, Shape: {keypoints_array.shape}, Dtype: {keypoints_array.dtype}")

        print(f'Keypoints saved to {file_path}')

    except Exception as e:
        print(f"Error saving keypoints to HDF5 file: {e}")
        # Re-raise: later cells read this file, so a half-written file must not
        # pass silently.
        raise

In [None]:
# Write the cleaned matches: one group per image, with one dataset per matched
# image. Each row of a dataset is [row in image-1 keypoints, row in image-2 keypoints].
file_path = os.path.join(feature_dir, "matches.h5")
with h5py.File(file_path, 'w') as f:
    for image_filename in unique_kpts:
        group = f.create_group(image_filename)
        for matched_image_filename, matches in out_match[image_filename].items():
            # These are indices, so store them as integers: float32 cannot
            # represent every integer above 2**24 exactly.
            group.create_dataset(
                matched_image_filename,
                data=np.asarray(matches, dtype=np.int64).reshape(-1, 2),
            )

print(f'Matches saved to {file_path}')

In [None]:
def load_image(image_path):
    """Read an image from disk and return it in RGB channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path`` (cv2.imread
            returns None rather than raising).
    """
    bgr = cv2.imread(image_path)
    if bgr is not None:
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    raise FileNotFoundError(f"Image not found at path: {image_path}")

def visualize_keypoints(image_dir, keypoints, num_images=5, point_size=1):
    """Plot keypoints over a random sample of images, side by side.

    Args:
        image_dir: Directory containing the images named by ``keypoints``' keys.
        keypoints: Mapping of image file name -> array whose columns 0 and 1
            are plotted as x and y.
        num_images: Number of images to show, capped at ``len(keypoints)``.
        point_size: Marker size passed to ``Axes.scatter``.
    """
    n_shown = min(num_images, len(keypoints))
    if n_shown <= 0:
        print('No keypoints to visualize.')
        return

    sample_keys = random.sample(list(keypoints.keys()), n_shown)
    # squeeze=False always returns a 2-D array of Axes, even when n_shown == 1
    # (plain subplots() would return a bare Axes that zip() cannot iterate).
    # The width scales with the number of panels (50 for the default 5).
    fig, axes = plt.subplots(1, n_shown, figsize=(10 * n_shown, 10), squeeze=False)
    for ax, key in zip(axes[0], sample_keys):
        img = load_image(os.path.join(image_dir, key))
        kp = keypoints[key]
        ax.imshow(img)
        ax.scatter(kp[:, 0], kp[:, 1], s=point_size, c='red', marker='o')
        ax.axis('off')
        ax.set_title(key)  # the image name identifies each panel
    plt.show()

def load_keypoints(feature_dir):
    """Load ``keypoints.h5`` from ``feature_dir`` as a dict of dataset name -> numpy array."""
    # os.path.join handles feature_dir with or without a trailing slash.
    with h5py.File(os.path.join(feature_dir, 'keypoints.h5'), 'r') as f:
        return {name: f[name][...] for name in f.keys()}

In [None]:
# Read the saved keypoints back from disk and plot a random sample of images.
keypoints = load_keypoints(feature_dir)
visualize_keypoints(image_dir=dirname, keypoints=keypoints)

In [None]:
# List the current directory's contents (e.g. to check that keypoints.h5 /
# matches.h5 were written).
!ls

In [None]:
# Make /kaggle/working the current directory so that relative paths, such as
# "matches.h5" in the FileLink cell, resolve there.
os.chdir(r'/kaggle/working')

In [None]:
# IPython-magic equivalent of the os.chdir('/kaggle/working') call in the
# previous cell; it also prints the new directory.
# NOTE(review): redundant with that cell — one of the two can be removed.
%cd /kaggle/working

In [None]:
# Render a clickable download link to matches.h5, resolved relative to the
# current working directory.
from IPython.display import FileLink
FileLink(path="matches.h5")