# Hi Kagglers 🙋‍♀️

*Please upvote the [original notebook](https://www.kaggle.com/code/ammarali32/image-matching-challenge-2022-baseline-kornia)*

<center>
    <h2 style="color: #dc3545"> Your UPVOTE can make my day 🤗 </h2>
</center>

![](https://storage.googleapis.com/kaggle-media/competitions/google-image-matching/trevi-canvas-licensed-nonoderivs.jpg)

# ***Install Libs***

In [None]:
dry_run = False
!ls /kaggle/input/k/oldufo
!pip install kornia --no-index --find-links=file:///kaggle/input/k/oldufo/imc2022-dependencies/pip/kornia/ --upgrade 
!pip install kornia_moons --no-index --find-links=file:///kaggle/input/k/oldufo/imc2022-dependencies/pip/kornia_moons/ --no-deps  --upgrade 
print('Done!')

# ***Import dependencies***

In [None]:
import os
import numpy as np
import cv2
import csv
from glob import glob
import torch
import matplotlib.pyplot as plt
import kornia
import kornia.feature as kornia_feature
from kornia_moons.feature import *

In [None]:
paths = ["/kaggle/input/k/oldufo/imc2022-dependencies/pretrained/AffNet.pth",
        "/kaggle/input/k/oldufo/imc2022-dependencies/pretrained/hardnet8v2.pt",
        "/kaggle/input/k/oldufo/imc2022-dependencies/pretrained/OriNet.pth",
        "/kaggle/input/k/oldufo/imc2022-dependencies/pretrained/HyNet_LIB.pth",
        "/kaggle/input/imgs-matching/tfeat-liberty.params",
        "/kaggle/input/imgs-matching/sosnet_32x32_liberty.pth",
        "/kaggle/input/k/oldufo/imc2022-dependencies/pretrained/keynet_pytorch.pth"
        ]

# ***CFG***

In [None]:
class CFG:
    """Notebook-wide settings: pluggable model registries plus detection/matching options.

    Every registry entry is a dict ``{"model": <module instance>, "path": <weights file or None>}``.
    The module instances are created once, when this class body executes, and the
    feature pipelines take them directly from here, so pipelines that select the
    same entry share a single module object.
    """

    # Orientation estimators. "pass" has no weights file (path is None).
    ori_module = {
        "pass": {
            "model": kornia_feature.PassLAF(),
            "path": None,
        },
        "orinet": {
            "model": kornia_feature.LAFOrienter(
                angle_detector=kornia_feature.OriNet(pretrained=False),
            ),
            "path": paths[2],
        },
    }

    # Affine-shape estimators ("alfaffnet" is the key the pipelines use by default).
    aff_module = {
        "alfaffnet": {
            "model": kornia_feature.LAFAffNetShapeEstimator(pretrained=False),
            "path": paths[0],
        },
    }

    # Patch descriptors, built without weights; weights are loaded later from "path".
    descriptor = {
        "hardnet8v2": {"model": kornia_feature.HardNet8(pretrained=False), "path": paths[1]},
        "HyNet": {"model": kornia_feature.HyNet(pretrained=False), "path": paths[3]},
        "TFeat": {"model": kornia_feature.TFeat(pretrained=False), "path": paths[4]},
        "SOSNet": {"model": kornia_feature.SOSNet(pretrained=False), "path": paths[5]},
    }

    # KeyNet detector weights (last entry of `paths`).
    detector_path = paths[-1]
    # Default keypoint budget for KeyNetAffNetHardNet's `num_features`.
    num_features = 4000
    # Matching strategy passed to kornia's DescriptorMatcher: one of "nn", "snn", "mnn", "smnn".
    DescriptorMatcher = "snn"

# ***Models***

In [None]:
class KeyNetAffNetHardNet(kornia_feature.LocalFeature):
    """KeyNet detector + AffNet shape estimation + a configurable patch descriptor.

    Despite the class name, the descriptor is chosen from ``CFG.descriptor`` and
    does not have to be HardNet.

    Args:
        ori_module_name: key into ``CFG.ori_module`` selecting the orientation estimator.
        aff_module_name: key into ``CFG.aff_module`` selecting the affine-shape estimator.
        descriptor_name: key into ``CFG.descriptor`` selecting the patch descriptor.
        num_features: number of keypoints requested from the KeyNet detector.
        upright: kept for interface compatibility; currently unused (orientation
            handling is determined by ``ori_module_name``).
        device: device the detector and descriptor are moved to; checkpoints are
            also loaded straight onto it.
    """
    def __init__(self,
                 ori_module_name = "pass",
                 aff_module_name = "alfaffnet",
                 descriptor_name = "hardnet8v2",
                 num_features: int = CFG.num_features,
                 upright: bool = True,
                 device: torch.device = torch.device('cuda')):
        ori_cfg = CFG.ori_module[ori_module_name]
        aff_cfg = CFG.aff_module[aff_module_name]
        desc_cfg = CFG.descriptor[descriptor_name]

        # num_features must be forwarded explicitly; otherwise KeyNetDetector
        # silently uses its own default and the CFG setting has no effect.
        detector = kornia_feature.KeyNetDetector(False,
                                                 num_features=num_features,
                                                 ori_module=ori_cfg["model"],
                                                 aff_module=aff_cfg["model"].eval()).to(device)
        # Detector and AffNet checkpoints wrap their weights under 'state_dict'.
        detector.model.load_state_dict(
            torch.load(CFG.detector_path, map_location=device)['state_dict'])
        detector.aff.load_state_dict(
            torch.load(aff_cfg["path"], map_location=device)['state_dict'])
        # Orientation modules without a weights file (path None) have nothing to load.
        if ori_cfg["path"] is not None:
            detector.ori.angle_detector.load_state_dict(
                torch.load(ori_cfg["path"], map_location=device)['state_dict'])

        descriptor = kornia_feature.LAFDescriptor(desc_cfg["model"],
                                                  patch_size=32,
                                                  grayscale_descriptor=True).to(device)
        # Descriptor checkpoints are plain state dicts (no 'state_dict' wrapper).
        descriptor.descriptor.load_state_dict(torch.load(desc_cfg["path"], map_location=device))
        super().__init__(detector, descriptor)

## *Utils*

In [None]:
src = '/kaggle/input/image-matching-challenge-2022/'

# Each data row of test.csv is [sample_id, batch_id, image_1_id, image_2_id];
# the header row is discarded.
with open(f'{src}/test.csv') as f:
    reader = csv.reader(f, delimiter=',')
    next(reader, None)  # drop the header line
    test_samples = list(reader)


def FlattenMatrix(M, num_digits=8):
    '''Serialize a matrix as space-separated scientific-notation values (row-major order).

    Used to write the fundamental-matrix column of the submission CSV.

    Args:
        M: numpy array or any array-like (nested lists, np.matrix, ...).
        num_digits: digits after the decimal point for each value.

    Returns:
        One string, e.g. ``'1.00000000e+00 0.00000000e+00 ...'``.
    '''
    # np.asarray accepts plain lists and turns np.matrix into an ndarray, whose
    # ravel() yields scalars in the same C order as ndarray.flatten().
    return ' '.join(f'{v:.{num_digits}e}' for v in np.asarray(M).ravel())


# Print the documentation of the kornia_moons plotting helper used during inference.
help(draw_LAF_matches)

# Styling for draw_LAF_matches: inliers and tentative matches are drawn, individual
# features are not (feature_color=None); vertical=False presumably places the two
# images side by side.
draw_dict = dict(
    inlier_color=(0.2, 1, 0.2),          # green
    tentative_color=(1, 1, 0.2, 0.5),    # light yellow, half transparent
    feature_color=None,
    vertical=False,
)

# ***Inference***

In [None]:
import gc

# Pairs with index >= how_many_to_fill get a random F instead of a real estimate
# (handy for quick debugging); -1 processes every pair.
how_many_to_fill = -1

device = torch.device('cuda')
# Three pipelines with the same KeyNet + AffNet front end and different patch
# descriptors; their matches are pooled before estimating F.
keynet_affnet_hardnet8 = KeyNetAffNetHardNet(device=device).eval()
keynet_affnet_hynet = KeyNetAffNetHardNet(descriptor_name="HyNet", device=device).eval()
keynet_affnet_sosnet = KeyNetAffNetHardNet(descriptor_name="SOSNet", device=device).eval()
local_features = (keynet_affnet_hardnet8, keynet_affnet_hynet, keynet_affnet_sosnet)

matcher = kornia_feature.DescriptorMatcher(CFG.DescriptorMatcher, 0.9)


def _load_rgb(batch_id, image_id):
    """Read one test image with OpenCV and convert it from BGR to RGB channel order."""
    path = f'{src}/test_images/{batch_id}/{image_id}.png'
    return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)


def _to_gray_tensor(image):
    """Turn an RGB uint8 image into a batched grayscale float tensor scaled to [0, 1] on `device`."""
    timg = kornia.image_to_tensor(image, False).float() / 255.
    return kornia.color.rgb_to_grayscale(timg).to(device)


def _laf_centers(lafs):
    """Return LAF centres as an (N, 2) numpy array (batch entries stacked)."""
    return kornia_feature.get_laf_center(lafs).detach().cpu().numpy().reshape(-1, 2)


def _pooled_matches(timg1, timg2):
    """Run every pipeline on both images and pool keypoints and matches.

    Each pipeline's match indices are shifted by the number of keypoints that
    earlier pipelines contributed, so they address the concatenated keypoint
    arrays instead of that pipeline's own keypoints.

    Returns:
        (kp1, kp2, match_idxs, lafs1, lafs2): pooled keypoint centres of each image,
        pooled (M, 2) index pairs, and the concatenated LAF tensors (for plotting).
        When no pipeline yields descriptors, match_idxs is empty and the rest are None.
    """
    kps1, kps2, lafs_list1, lafs_list2, idx_list = [], [], [], [], []
    offset1 = offset2 = 0
    for model in local_features:
        lafs_a, _, desc_a = model(timg1)
        lafs_b, _, desc_b = model(timg2)
        # Skip pipelines that detected nothing in either image.
        if desc_a.size(1) == 0 or desc_b.size(1) == 0:
            continue
        _, idxs = matcher(desc_a[0], desc_b[0])
        # copy(): .numpy() would otherwise alias the tensor's memory.
        idxs = idxs.detach().cpu().numpy().copy()
        idxs[:, 0] += offset1
        idxs[:, 1] += offset2
        kp_a, kp_b = _laf_centers(lafs_a), _laf_centers(lafs_b)
        offset1 += len(kp_a)
        offset2 += len(kp_b)
        kps1.append(kp_a)
        kps2.append(kp_b)
        lafs_list1.append(lafs_a)
        lafs_list2.append(lafs_b)
        idx_list.append(idxs)
    if not idx_list:
        return None, None, np.empty((0, 2), dtype=np.int64), None, None
    return (np.concatenate(kps1), np.concatenate(kps2), np.concatenate(idx_list),
            torch.cat(lafs_list1, dim=1), torch.cat(lafs_list2, dim=1))


F_dict = {}
for i, row in enumerate(test_samples):
    sample_id, batch_id, image_1_id, image_2_id = row

    if how_many_to_fill >= 0 and i >= how_many_to_fill:
        F_dict[sample_id] = np.random.rand(3, 3)
        continue

    # Load the images.
    image_1 = _load_rgb(batch_id, image_1_id)
    image_2 = _load_rgb(batch_id, image_2_id)

    # Extract and match features.
    with torch.no_grad():
        cur_kp1, cur_kp2, match_idxs, lafs1, lafs2 = _pooled_matches(
            _to_gray_tensor(image_1), _to_gray_tensor(image_2))

    F, inlier_mask = None, None
    # Make sure we do not trigger an exception here.
    if len(match_idxs) > 8:
        F, inlier_mask = cv2.findFundamentalMat(cur_kp1[match_idxs[:, 0]], cur_kp2[match_idxs[:, 1]],
                                                cv2.USAC_MAGSAC,
                                                ransacReprojThreshold=0.25,
                                                confidence=0.79,
                                                maxIters=10000)
    # findFundamentalMat may return None when no model is found; submit zeros
    # instead of crashing the whole run.
    if F is None or F.shape != (3, 3):
        F_dict[sample_id] = np.zeros((3, 3))
        continue
    F_dict[sample_id] = F
    gc.collect()

    if i < 2:
        # Use the pooled LAFs so the offset match indices stay valid; `bool`
        # replaces the np.bool alias removed in NumPy 1.24.
        draw_LAF_matches(lafs1.cpu(), lafs2.cpu(),
                         match_idxs, image_1, image_2,
                         inlier_mask=inlier_mask.astype(bool), draw_dict=draw_dict)
        plt.title(f'{image_1_id}-{image_2_id}')
        plt.axis('off')
        plt.show()

with open('submission.csv', 'w') as f:
    f.write('sample_id,fundamental_matrix\n')
    for sample_id, F in F_dict.items():
        f.write(f'{sample_id},{FlattenMatrix(F)}\n')

if dry_run:
    !cat submission.csv

<center>
    <h2 style="color: #dc3545"> Thanks for reading 🤗 </h2>
</center>