## Welcome Back Kagglers 

Credit: https://www.kaggle.com/code/ammarali32/imc-2022-kornia-loftr-from-0-533-to-0-721


<center>
    <h2 style="color: #022047"> If you found this notebook useful, please upvote it. </h2>
</center>

![](https://storage.googleapis.com/kaggle-media/competitions/google-image-matching/trevi-canvas-licensed-nonoderivs.jpg)

# ***Install Libs***

In [None]:
# NOTE(review): `dry_run` is not read by any code visible in this notebook.
dry_run = False
# Install pinned kornia / kornia_moons versions from wheel files shipped in the
# `kornia-loftr` input dataset.
!pip install ../input/kornia-loftr/kornia-0.6.4-py2.py3-none-any.whl
!pip install ../input/kornia-loftr/kornia_moons-0.1.9-py3-none-any.whl

# ***Import dependencies***

In [None]:
import os
import numpy as np
import cv2
import csv
from glob import glob
import torch
import matplotlib.pyplot as plt
import kornia
from kornia_moons.feature import *
import kornia as K
import kornia.feature as KF
import gc


# ***Model***

In [None]:
# Build LoFTR and load the outdoor weights from the local checkpoint file
# (pretrained=None: the weights come from the checkpoint below instead).
device = torch.device('cuda')
matcher = KF.LoFTR(pretrained=None)
# map_location='cpu': materialise the checkpoint on the CPU, where the freshly
# built model still lives, rather than on whatever device it was saved from.
# This avoids a transient extra GPU copy (or a failure without CUDA).
checkpoint = torch.load("../input/kornia-loftr/loftr_outdoor.ckpt", map_location='cpu')
matcher.load_state_dict(checkpoint['state_dict'])
del checkpoint  # the weights now live in `matcher`; drop the duplicate
# Inference only: move to the GPU, switch to eval mode and use fp16.
matcher = matcher.to(device).eval().half()

## *Utils*

In [None]:
src = '/kaggle/input/image-matching-challenge-2022/'

# Every data row of test.csv (header dropped), each a list of strings; the
# columns are unpacked by `preprocess` as
# [sample_id, batch_id, image_1_id, image_2_id].
with open(f'{src}/test.csv') as f:
    rows = csv.reader(f, delimiter=',')
    next(rows, None)  # discard the header line (no-op on an empty file)
    test_samples = list(rows)


def FlattenMatrix(M, num_digits=8):
    '''Serialise a matrix for the submission CSV.

    Args:
        M: array-like object with a ``flatten()`` method (e.g. a NumPy array).
        num_digits: digits after the decimal point in each value.

    Returns:
        The flattened entries, each in scientific notation, joined by spaces.
    '''
    spec = f'.{num_digits}e'
    return ' '.join(format(value, spec) for value in M.flatten())


def load_torch_image(fname, max_dim=1200):
    '''Load an image and resize it so that its longer side is `max_dim` pixels.

    Small images are upscaled as well as large ones downscaled.

    Args:
        fname: path of the image file.
        max_dim: target length in pixels of the longer side after resizing
            (defaults to 1200).

    Returns:
        (img, (orig_w, orig_h), (w, h)) where `img` is a 1x3xHxW float tensor
        with values in [0, 1] in RGB channel order, and (w, h) is its size.

    Raises:
        FileNotFoundError: if OpenCV cannot read `fname`.
    '''
    img = cv2.imread(fname)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f'Could not read image: {fname}')
    orig_h, orig_w = img.shape[:2]
    scale = max_dim / max(orig_h, orig_w)
    w = int(orig_w * scale)
    h = int(orig_h * scale)
    img = cv2.resize(img, (w, h))
    img = K.image_to_tensor(img, False).float() / 255.
    img = K.color.bgr_to_rgb(img)  # OpenCV loads BGR
    return img, (orig_w, orig_h), (w, h)

# ***Inference***

In [None]:
# Show the number of CPU cores available to the notebook (IPython shell command).
!nproc

In [None]:
from multiprocessing import Pool

# Test-time augmentation: `preprocess` additionally rotates each image by
# +alpha and -alpha before matching.
alpha = 15.0
# Batched arguments for kornia's get_rotation_matrix2d: `angle` has shape (1,)
# (kornia takes degrees) and `scale` has shape (1, 2) with unit x/y scale.
# Annotated with the tensor *type* `torch.Tensor`; `torch.tensor` is a
# factory function, not a type.
angle: torch.Tensor = torch.ones(1) * alpha
scale: torch.Tensor = torch.ones(1, 2)

def _augmented_gray_stack(image, size, M, Minv):
    '''Return a grayscale stack of 4 views of `image`.

    The views are, in order: original, warped by M, warped by Minv, and
    horizontally flipped. `size` is (w, h) as returned by `load_torch_image`;
    `warp_affine` expects its `dsize` as (h, w).
    '''
    dsize = (size[1], size[0])
    rotated = K.geometry.warp_affine(image, M.to(image.device), dsize=dsize)
    inv_rotated = K.geometry.warp_affine(image, Minv.to(image.device), dsize=dsize)
    flipped = K.geometry.transform.hflip(image)
    return K.color.rgb_to_grayscale(torch.cat((image, rotated, inv_rotated, flipped)))


def preprocess(row):
    '''Load one test pair and build the augmented views fed to LoFTR.

    Args:
        row: a test.csv row [sample_id, batch_id, image_1_id, image_2_id].

    Returns:
        dict with keys:
            'sample_id': the row's sample id.
            'image_1_info' / 'image_2_info': (orig_size, resized_size); both
                sizes are (w, h).
            'images_1' / 'images_2': grayscale 4-view stacks ordered
                [original, rotated by M, rotated by Minv, horizontally flipped].
            'image_1' / 'image_2': the resized RGB images.
            'affine_mats': (M, Minv), rotations by +angle / -angle.
            'st': time.perf_counter() stamp taken once preprocessing is done.
    '''
    # Local import: `time` is otherwise imported only in a later notebook
    # cell, so calling this function before that cell raised NameError.
    import time

    sample_id, batch_id, image_1_id, image_2_id = row
    image_dir = f'{src}/test_images/{batch_id}'
    image_1, orig_sz1, sz1 = load_torch_image(f'{image_dir}/{image_1_id}.png')
    image_2, orig_sz2, sz2 = load_torch_image(f'{image_dir}/{image_2_id}.png')

    # Rotation matrices about the centre (x, y) of the resized image_1.
    # NOTE(review): the same matrices are also applied to image_2, whose size
    # may differ; `deep` undoes them with the same matrices, so coordinates stay
    # consistent, but image_2 content far from this centre can leave the frame.
    center: torch.Tensor = torch.ones(1, 2)
    center[..., 0] = sz1[0] / 2  # x
    center[..., 1] = sz1[1] / 2  # y
    M: torch.Tensor = K.geometry.get_rotation_matrix2d(center, angle, scale)  # 1x2x3
    Minv: torch.Tensor = K.geometry.get_rotation_matrix2d(center, -angle, scale)

    return {
        'sample_id': sample_id,
        'image_1_info': (orig_sz1, sz1),
        'image_2_info': (orig_sz2, sz2),
        'images_1': _augmented_gray_stack(image_1, sz1, M, Minv),
        'images_2': _augmented_gray_stack(image_2, sz2, M, Minv),
        'image_1': image_1,
        'image_2': image_2,
        'affine_mats': (M, Minv),
        'st': time.perf_counter(),
    }

def _matching(model, feat_c0, feat_c1, feat_f0, feat_f1, data):
    '''Run LoFTR's coarse and fine matching stages on precomputed features.

    The caller supplies flattened, position-encoded coarse features
    (feat_c0/feat_c1) and fine feature maps (feat_f0/feat_f1). The model's
    matching stages store their outputs in `data`; nothing is returned.
    '''
    # Coarse-level transformer, no padding masks.
    coarse0, coarse1 = model.loftr_coarse(feat_c0, feat_c1, None, None)
    model.coarse_matching(coarse0, coarse1, data, mask_c0=None, mask_c1=None)

    # Fine-level windows around the coarse matches; the fine transformer is
    # only run when at least one coarse match was found.
    window0, window1 = model.fine_preprocess(feat_f0, feat_f1, coarse0, coarse1, data)
    found_any = window0.size(0) > 0
    if found_any:
        window0, window1 = model.loftr_fine(window0, window1)

    model.fine_matching(window0, window1, data)

def deep(data):
    '''Match one preprocessed pair with LoFTR over several augmented views.

    Args:
        data: dict produced by `preprocess`. Reads 'affine_mats' (M, Minv),
            'image_1_info' / 'image_2_info' ((orig_w, orig_h), (w, h)) and the
            grayscale stacks 'images_1' / 'images_2', whose view order (per
            `preprocess`) is [original, rotated by M, rotated by Minv, hflipped].

    Returns:
        The same dict with added keys:
            'mkpts0', 'mkpts1': matched keypoints (x, y) from all six view
                pairs, mapped back to the resized, un-augmented images.
            'mkpts0_orig', 'mkpts1_orig': the same points rescaled to the
                original image resolution.
    '''
    with torch.no_grad():
        M, Minv = data['affine_mats']
        orig_sz1, sz1 = data['image_1_info']
        orig_sz2, sz2 = data['image_2_info']
        # One backbone pass per side over all 4 views (fp16 on the GPU),
        # producing coarse (c*) and fine (f*) feature maps.
        c0s, f0s = matcher.backbone(data['images_1'].to(device).half())
        c1s, f1s = matcher.backbone(data['images_2'].to(device).half())
        # Shared state for the matching stages. Image sizes are stored as
        # [h, w] (sz* is (w, h)); one view pair is matched at a time.
        meta = {
            'bs': 1,
            'hw0_i': [sz1[1], sz1[0]],
            'hw1_i': [sz2[1], sz2[0]],
            'hw0_c': c0s.shape[2:], 'hw1_c': c1s.shape[2:],
            'hw0_f': f0s.shape[2:], 'hw1_f': f1s.shape[2:],
        }

        mkpts0, mkpts1, batch_indexes = [], [], []
        # (view of image_1, view of image_2) pairs to match:
        # plain/plain, rot/plain, inv-rot/plain, flip/flip, plain/rot,
        # plain/inv-rot. `bi` tags each match with the pair it came from.
        queries = [(0, 0), (1, 0), (2, 0), (3, 3), (0, 1), (0, 2)]
        for bi, (i, j) in enumerate(queries):
            feat_c0 = c0s[i:i+1]
            feat_c1 = c1s[j:j+1]
            feat_f0 = f0s[i:i+1]
            feat_f1 = f1s[j:j+1]

            # Add positional encoding, then flatten the coarse maps from
            # (n, c, h, w) to (n, h*w, c) tokens.
            feat_c0 = matcher.pos_encoding(feat_c0).permute(0, 2, 3, 1)
            n, h, w, c = feat_c0.shape
            feat_c0 = feat_c0.reshape(n, -1, c)

            feat_c1 = matcher.pos_encoding(feat_c1).permute(0, 2, 3, 1)
            n1, h1, w1, c1 = feat_c1.shape
            feat_c1 = feat_c1.reshape(n1, -1, c1)
            _matching(matcher, feat_c0, feat_c1, feat_f0, feat_f1, meta)

            # Results are read back from `meta`; the next query overwrites them.
            mkpts0.append(meta['mkpts0_f'].float().cpu().numpy())
            mkpts1.append(meta['mkpts1_f'].float().cpu().numpy())
            batch_indexes.append(np.full(len(meta['mkpts0_f']), bi, dtype=np.int32))
    mkpts0 = np.concatenate(mkpts0)
    mkpts1 = np.concatenate(mkpts1)
    batch_indexes = np.concatenate(batch_indexes)

    # Undo the augmentations: points found in a view warped by M are mapped
    # back with Minv (and vice versa) by applying the 2x3 matrix to [x, y, 1].
    # NOTE(review): assumes warp_affine sends a source point p to M @ [p, 1];
    # M/Minv are built around image_1's centre and reused for image_2's views.
    mask = (batch_indexes == 1)
    mkpts0[mask] = np.concatenate([mkpts0[mask], np.ones((np.sum(mask), 1))], axis=1) @ Minv[0].cpu().numpy().T
    mask = (batch_indexes == 2)
    mkpts0[mask] = np.concatenate([mkpts0[mask], np.ones((np.sum(mask), 1))], axis=1) @ M[0].cpu().numpy().T
    # Both images were flipped horizontally: mirror x back.
    # NOTE(review): a pixel-index flip maps x to (width - 1 - x); confirm
    # whether `sz - x` is off by one pixel for LoFTR's coordinate convention.
    mask = (batch_indexes == 3)
    mkpts0[mask, 0] = sz1[0] - mkpts0[mask, 0]
    mkpts1[mask, 0] = sz2[0] - mkpts1[mask, 0]
    mask = (batch_indexes == 4)
    mkpts1[mask] = np.concatenate([mkpts1[mask], np.ones((np.sum(mask), 1))], axis=1) @ Minv[0].cpu().numpy().T
    mask = (batch_indexes == 5)
    mkpts1[mask] = np.concatenate([mkpts1[mask], np.ones((np.sum(mask), 1))], axis=1) @ M[0].cpu().numpy().T


    # Rescale from resized to original resolution, per axis (sizes are (w, h)).
    mkpts0_orig = mkpts0 / sz1 * orig_sz1
    mkpts1_orig = mkpts1 / sz2 * orig_sz2

    data['mkpts0'] = mkpts0
    data['mkpts1'] = mkpts1
    data['mkpts0_orig'] = mkpts0_orig
    data['mkpts1_orig'] = mkpts1_orig
    return data

def ransac(data, *, threshold=0.150, confidence=0.9999, max_iters=250000):
    '''Estimate the fundamental matrix of one matched pair with MAGSAC.

    Args:
        data: dict produced by `deep`. 'mkpts0' is used to count matches;
            'mkpts0_orig' / 'mkpts1_orig' are the N x 2 keypoints at original
            image resolution.
        threshold: MAGSAC inlier threshold in pixels.
        confidence: desired RANSAC confidence.
        max_iters: maximum number of RANSAC iterations.

    Returns:
        `data` with 'F' (3x3 ndarray) and 'inliers' added. When there are
        fewer than 8 matches, or OpenCV does not return a single 3x3 matrix,
        'F' is all zeros and 'inliers' is an empty list. Otherwise 'inliers'
        is the boolean version of OpenCV's inlier mask.
    '''
    F, inliers = None, None
    if len(data['mkpts0']) > 7:
        F, inliers = cv2.findFundamentalMat(
            data['mkpts0_orig'], data['mkpts1_orig'],
            cv2.USAC_MAGSAC, threshold, confidence, max_iters)
    # findFundamentalMat can fail (F is None) or return a stacked set of
    # solutions. Fall back to a zero matrix for this pair instead of aborting
    # the whole run.
    if F is None or inliers is None or F.shape != (3, 3):
        data['F'] = np.zeros((3, 3))
        data['inliers'] = []
    else:
        data['F'] = F
        data['inliers'] = inliers > 0
    return data


In [None]:
# sample_id -> estimated fundamental matrix, filled by the loop below.
F_dict = {}

# `preprocess` reads `time` as a global when it is called; the lazy pipeline
# below only calls it after this import has run.
import time

# Lazy pipeline: `map` does no work until items are pulled. Pool.imap iterates
# its input in this (parent) process, so `preprocess` and `deep` (GPU) run
# here, while only `ransac` runs in the 2 worker processes.
# NOTE(review): each whole `data` dict (including image tensors) is pickled to
# a worker and back, since the loop below reads 'image_1' from the result.
processed_samples = map(preprocess, test_samples)
matched_samples = map(deep, processed_samples)

with Pool(2) as pool:
    for i, data in enumerate(pool.imap(ransac, matched_samples, chunksize=4)):
        F_dict[data['sample_id']] = data['F']
        inliers = data['inliers']
        # Elapsed time since `preprocess` finished for this sample (includes
        # matching, queueing and RANSAC).
        nd = time.perf_counter()
        # Visualise the first 3 samples that have inliers.
        if (i < 3) and len(inliers) > 0:
            # Resized-image coordinates, matching the resized images drawn.
            mkpts0 = data['mkpts0']
            mkpts1 = data['mkpts1']
            image_1 = data['image_1']
            image_2 = data['image_2']

            print(f"Running time: {nd-data['st']:.3f}s")
            # Each keypoint becomes a LAF with scale 1 and orientation 1, used
            # only for drawing; correspondence k pairs keypoint k of each image.
            draw_LAF_matches(
            KF.laf_from_center_scale_ori(torch.from_numpy(mkpts0).view(1,-1, 2),
                                        torch.ones(mkpts0.shape[0]).view(1,-1, 1, 1),
                                        torch.ones(mkpts0.shape[0]).view(1,-1, 1)),

            KF.laf_from_center_scale_ori(torch.from_numpy(mkpts1).view(1,-1, 2),
                                        torch.ones(mkpts1.shape[0]).view(1,-1, 1, 1),
                                        torch.ones(mkpts1.shape[0]).view(1,-1, 1)),
            torch.arange(mkpts0.shape[0]).view(-1,1).repeat(1,2),
            K.tensor_to_image(image_1),
            K.tensor_to_image(image_2),
            inliers,
            draw_dict={'inlier_color': (0.2, 1, 0.2),
                       'tentative_color': None, 
                       'feature_color': (0.2, 0.5, 1), 'vertical': False})

# One CSV row per sample: "<sample_id>,<flattened F entries, space-separated>".
with open('submission.csv', 'w') as f:
    f.write('sample_id,fundamental_matrix\n')
    f.writelines(
        f'{sample_id},{FlattenMatrix(F)}\n'
        for sample_id, F in F_dict.items()
    )

<center>
    <h2 style="color: #022047"> Thanks for reading 🤗  </h2>
</center>