# ***Install Libs***

In [None]:
!pip install ../input/kornia-loftr/kornia-0.6.4-py2.py3-none-any.whl
!pip install ../input/kornia-loftr/kornia_moons-0.1.9-py3-none-any.whl

In [None]:
import cv2
import numpy as np
import pandas as pd
import torch
import kornia as K
import kornia.feature as KF
from kornia_moons.feature import *

import kornia.feature.loftr as loftr

import gc

import torchvision.transforms as transforms
import torch

def sift_image(im0, max_side=840):
    '''Load an image from disk and resize it so its longer side is ``max_side`` pixels.

    Args:
        im0: path of the image file to read.
        max_side: target length, in pixels, of the longer image side (default 840).
    Returns:
        The resized image as produced by ``cv2.imread`` (BGR channel order).
    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    '''
    path = im0
    im0 = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising.
    if im0 is None:
        raise FileNotFoundError(f'Could not read image: {path}')
    scale = max_side / max(im0.shape[0], im0.shape[1])
    w = int(im0.shape[1] * scale)
    h = int(im0.shape[0] * scale)
    return cv2.resize(im0, (w, h))

def ExtractSiftFeatures(image, detector, num_features):
    '''Compute SIFT keypoints and descriptors for a BGR image.

    Args:
        image: colour image as loaded by ``cv2.imread`` (BGR channel order).
        detector: OpenCV feature detector exposing ``detectAndCompute``.
        num_features: maximum number of keypoints/descriptors to return.
    Returns:
        ``(keypoints, descriptors)``, each truncated to at most ``num_features``
        entries. ``descriptors`` may be ``None`` when no keypoints were found.
    '''
    # Images in this notebook come from cv2.imread (BGR); RGB2GRAY would swap
    # the red/blue luminance weights.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    keypoints, descriptors = detector.detectAndCompute(gray, None)
    # Limit the features themselves, not the (keypoints, descriptors) tuple.
    keypoints = keypoints[:num_features]
    if descriptors is not None:
        descriptors = descriptors[:num_features]
    return keypoints, descriptors

def ArrayFromCvKps(kps):
    '''Convert OpenCV keypoints into an ``(n, 2)`` float array of their ``pt`` positions.

    The reshape keeps the two-column shape even for an empty keypoint list, so
    callers can index ``[:, 0]`` or fancy-index rows without special-casing.
    '''
    return np.array([kp.pt for kp in kps], dtype=float).reshape(-1, 2)

def FlattenMatrix(M, num_digits=8):
    '''Format a matrix as a space-separated, row-major string of scientific-notation values.

    Used to write the fundamental matrices into the submission CSV.

    Args:
        M: matrix-like value (NumPy array, nested list, CPU torch tensor, ...).
        num_digits: digits after the decimal point for each value.
    Returns:
        e.g. ``'1.00000000e+00 0.00000000e+00 ...'``.
    '''
    # np.asarray accepts any array-like; float conversion is exact for the
    # float32/float64 inputs used here, so the text matches M.flatten().
    values = np.asarray(M, dtype=float).ravel()
    return ' '.join(f'{v:.{num_digits}e}' for v in values)


def fast_cleanup(dp1,dp2,sf1,sf2, device):
    '''
    Keep the LoFTR correspondences whose image-1 point lies within 1 px of a SIFT match.

    Args:
        dp1, dp2: LoFTR matched points in image 1 / image 2; row k of dp1
            corresponds to row k of dp2.
        sf1: SIFT matched points in image 1.
        sf2: SIFT matched points in image 2 (not used by the filter; kept for
            interface compatibility).
        device: torch device on which the pairwise distances are computed.
    Returns:
        (tmp1, tmp2): float32 NumPy arrays holding the surviving rows of dp1 and
        dp2, in their original order. Each LoFTR match appears at most once,
        even when several SIFT points are nearby.
    '''
    tdp1 = torch.as_tensor(np.asarray(dp1), dtype=torch.float32, device=device).reshape(-1, 2)
    tsf1 = torch.as_tensor(np.asarray(sf1), dtype=torch.float32, device=device).reshape(-1, 2)
    tdp2 = torch.as_tensor(np.asarray(dp2), dtype=torch.float32).reshape(-1, 2)

    # (N, M) pairwise distances via broadcasting; no need to materialise N copies of sf1.
    dists = torch.linalg.norm(tsf1[None, :, :] - tdp1[:, None, :], dim=2)
    # One boolean per LoFTR match; moved to CPU so it can index the CPU tensors below.
    keep = (dists <= 1).any(dim=1).cpu()

    tmp1 = tdp1.cpu()[keep].numpy()
    tmp2 = tdp2[keep].numpy()
    return tmp1, tmp2


In [None]:
import csv

src = '/kaggle/input/image-matching-challenge-2022/'

# Each row is unpacked later as (sample_id, batch_id, image_1_id, image_2_id).
# newline='' is what the csv module expects for files passed to csv.reader.
with open(f'{src}/test.csv', newline='') as f:
    reader = csv.reader(f, delimiter=',')
    next(reader, None)  # skip the header row (tolerates an empty file)
    test_samples = list(reader)


# ***Inference***

In [None]:
num_features = 8000

mech = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(mech)

# pretrained=None: the weights come from the local checkpoint loaded below.
matcher = loftr.LoFTR(pretrained=None)
# map_location lets a checkpoint holding GPU tensors load on the CPU fallback path.
checkpoint = torch.load("../input/kornia-loftr/loftr_outdoor.ckpt", map_location=device)
matcher.load_state_dict(checkpoint['state_dict'])
matcher = matcher.to(device).eval()

# SIFT capped at num_features keypoints; the very negative thresholds appear
# intended to keep (almost) every candidate rather than filter by contrast/edges.
detector = cv2.SIFT_create(num_features, contrastThreshold=-10000, edgeThreshold=-10000)
# Brute-force L2 matcher with mutual (cross-check) filtering for SIFT descriptors.
bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)

In [None]:
F_dict = {}
import time

# Which correspondence set produced F, so the visualisation draws the same
# points that the inlier mask refers to:
#   0 = LoFTR only, 1 = LoFTR points confirmed by SIFT, 2 = SIFT only.
lesser = False

MIN_MATCHES = 8  # fundamental-matrix estimation here needs at least 8 correspondences
# method, reprojection threshold (px), confidence, max iterations
MAGSAC_ARGS = (cv2.USAC_MAGSAC, 0.1845, 0.999999, 220000)


def _estimate_fundamental(pts1, pts2):
    '''Estimate F with MAGSAC; return (F, boolean inlier mask), or (None, None) on failure.'''
    F, inliers = cv2.findFundamentalMat(pts1, pts2, *MAGSAC_ARGS)
    # OpenCV may return None (or a non-3x3 result) instead of raising when
    # estimation fails.
    if F is None or inliers is None or F.shape != (3, 3):
        return None, None
    return F, inliers > 0


def _draw_matches(pts1, pts2, img1, img2, inliers):
    '''Visualise correspondences pts1[k] <-> pts2[k], highlighting the inliers.'''
    def _lafs(pts):
        # Unit-scale, unit-orientation LAFs centred on each point.
        centers = torch.from_numpy(np.ascontiguousarray(pts, dtype=np.float32)).view(1, -1, 2)
        n = centers.shape[1]
        return KF.laf_from_center_scale_ori(centers,
                                            torch.ones(n).view(1, -1, 1, 1),
                                            torch.ones(n).view(1, -1, 1))

    draw_LAF_matches(_lafs(pts1),
                     _lafs(pts2),
                     torch.arange(len(pts1)).view(-1, 1).repeat(1, 2),
                     img1,
                     img2,
                     inliers,
                     draw_dict={'inlier_color': (0.2, 1, 0.2),
                                'tentative_color': None,
                                'feature_color': (0.2, 0.5, 1), 'vertical': False})


for i, row in enumerate(test_samples):
    sample_id, batch_id, image_1_id, image_2_id = row

    ## load image pairs ##
    image_1 = sift_image(f'{src}/test_images/{batch_id}/{image_1_id}.png')
    image_2 = sift_image(f'{src}/test_images/{batch_id}/{image_2_id}.png')

    ## SIFT ##
    keypoints_1, descriptors_1 = ExtractSiftFeatures(image_1, detector, num_features)
    keypoints_2, descriptors_2 = ExtractSiftFeatures(image_2, detector, num_features)

    ## LoFTR ##
    timage_1 = K.color.bgr_to_rgb(K.image_to_tensor(image_1, False).float() / 255.).to(device)
    timage_2 = K.color.bgr_to_rgb(K.image_to_tensor(image_2, False).float() / 255.).to(device)
    input_dict = {"image0": K.color.rgb_to_grayscale(timage_1), "image1": K.color.rgb_to_grayscale(timage_2)}
    with torch.no_grad():
        correspondences = matcher(input_dict)

    ## load matching points to cpu ##
    mkpts1 = correspondences['keypoints0'].cpu().numpy()
    mkpts2 = correspondences['keypoints1'].cpu().numpy()

    ### Brute-Force Matching ###
    # Descriptors can be None when no keypoints were detected; bf.match would fail on them.
    if descriptors_1 is None or descriptors_2 is None:
        cv_matches = []
    else:
        cv_matches = bf.match(descriptors_1, descriptors_2)
    # reshape keeps an (n, 2) shape even with zero matches, so column indexing works.
    matches = np.array([[m.queryIdx, m.trainIdx] for m in cv_matches], dtype=int).reshape(-1, 2)
    cur_kp_1 = ArrayFromCvKps(keypoints_1)
    cur_kp_2 = ArrayFromCvKps(keypoints_2)
    sift_pts1 = cur_kp_1[matches[:, 0]]
    sift_pts2 = cur_kp_2[matches[:, 1]]

    ## Candidate correspondence sets, in priority order ##
    candidates = []
    if len(mkpts1) >= MIN_MATCHES and len(sift_pts1) >= MIN_MATCHES:
        # LoFTR matches confirmed by a nearby SIFT match in image 1.
        f1, f2 = fast_cleanup(mkpts1, mkpts2, sift_pts1, sift_pts2, device)
        if len(f1) >= MIN_MATCHES:
            candidates.append((1, f1, f2))
    if len(mkpts1) >= MIN_MATCHES:
        candidates.append((0, mkpts1, mkpts2))
    if len(sift_pts1) >= MIN_MATCHES:
        candidates.append((2, sift_pts1, sift_pts2))

    # Use the first candidate whose estimation succeeds; later candidates are
    # only tried if an earlier one is unavailable or fails.
    F = inliers = None
    for source, cand1, cand2 in candidates:
        F, inliers = _estimate_fundamental(cand1, cand2)
        if F is not None:
            lesser, inpt1, inpt2 = source, cand1, cand2
            break

    if F is None:
        # Neither matcher gave a usable estimate: submit the zero matrix.
        F_dict[sample_id] = np.zeros((3, 3))
        continue
    F_dict[sample_id] = F

    ## Draw keypoint matches across both Image1 and Image2 ##
    if i < 3:
        _draw_matches(inpt1, inpt2, image_1, image_2, inliers)
    
    
## write fundamental matrix into csv ##
import traceback

try:
    with open('submission.csv', 'w') as f:
        f.write('sample_id,fundamental_matrix\n')
        for sample_id, F in F_dict.items():
            f.write(f'{sample_id},{FlattenMatrix(F)}\n')
except Exception:
    # Best-effort fallback so a submission.csv always exists, but report the
    # failure instead of silently hiding it.
    traceback.print_exc()
    pd.DataFrame().to_csv("submission.csv")