In [None]:
# Environment setup for the offline Kaggle run.
# Make code shipped inside the input datasets importable; the next cell imports
# FeatureMatching.*, config.* and models.* — presumably from these paths (TODO confirm
# which directory provides which package).
import sys
sys.path.append('../input/imc-2022-loftr-quadtreeattention')
sys.path.append('../input/imc-2022-loftr-quadtreeattention/QuadTreeAttention')
sys.path.append("../input/super-glue-pretrained-network")

# Install pinned dependencies from local wheel files bundled in the input dataset
# (torch 1.8.2/cu102, torchvision 0.9.2, kornia_moons 0.1.9, loguru 0.6.0,
# einops 0.4.1, timm 0.5.4).
!pip install ../input/imc-2022-loftr-quadtreeattention/packages/torch-1.8.2cu102-cp37-cp37m-linux_x86_64.whl
!pip install ../input/imc-2022-loftr-quadtreeattention/packages/torchvision-0.9.2cu102-cp37-cp37m-linux_x86_64.whl
!pip install ../input/imc-2022-loftr-quadtreeattention/packages/kornia_moons-0.1.9-py3-none-any.whl
!pip install ../input/imc-2022-loftr-quadtreeattention/packages/loguru-0.6.0-py3-none-any.whl
!pip install ../input/imc-2022-loftr-quadtreeattention/packages/einops-0.4.1-py3-none-any.whl
!pip install ../input/imc-2022-loftr-quadtreeattention/packages/timm-0.5.4-py3-none-any.whl

# QuadTreeAttention is installed from its source directory. It is copied to a
# writable location first because pip needs to write into it.
!cp -r ../input/imc-2022-loftr-quadtreeattention/QuadTreeAttention/ ../working/ # input folder is read only
!cd ../working/QuadTreeAttention && pip install .

In [None]:
from kornia_moons.feature import *
import kornia.feature as KF
import kornia as K
import numpy as np
import pydegensac
import torch
import cv2
import csv
import gc

from FeatureMatching.src.lightning.lightning_loftr import PL_LoFTR
from FeatureMatching.src.utils.plotting import make_matching_figure
from config.default import get_cfg_defaults

from models.matching import Matching
from models.utils import (compute_pose_error, compute_epipolar_error,
                          estimate_pose, make_matching_plot,
                          error_colormap, AverageTimer, pose_auc, read_image,
                          rotate_intrinsics, rotate_pose_inplane,
                          scale_intrinsics)

In [None]:
# All models and image tensors in this notebook are placed on the CUDA device.
device = torch.device("cuda")

# QuadTree-attention LoFTR with the outdoor checkpoint, switched to eval mode.
LOFTR_CKPT = '../input/imc-2022-loftr-quadtreeattention/checkpoints/quadtree_outdoor.ckpt'
model_loftr = PL_LoFTR(get_cfg_defaults(), pretrained_ckpt=LOFTR_CKPT).to(device).eval()

In [None]:
# Arguments forwarded to models.utils.read_image for the SuperGlue branch.
# resize=[-1] presumably means "keep the original resolution" — TODO confirm
# against read_image's resize handling.
resize_float = True
resize = [-1]

# SuperPoint detector + SuperGlue matcher settings, passed to Matching as one config.
config = dict(
    superpoint=dict(nms_radius=4, keypoint_threshold=0.005, max_keypoints=1024),
    superglue=dict(weights="outdoor", sinkhorn_iterations=20, match_threshold=0.2),
)
model_sg = Matching(config).eval().to(device)

In [5]:
# Size LoFTR inputs are resized to. It is passed to cv2.resize as dsize,
# i.e. (width, height) — see load_torch_image.
DIM = (640, 1120)
src = '../input/image-matching-challenge-2022'

# Rows of test.csv as lists of strings; the main loop unpacks each row as
# (sample_id, batch_id, image_0_id, image_1_id).
test_samples = []
# newline='' is what the csv module requires for file objects it reads.
with open(f'{src}/test.csv', newline='') as f:
    reader = csv.reader(f, delimiter=',')
    next(reader, None)  # skip the header row (no-op on an empty file)
    test_samples = list(reader)


def FlattenMatrix(M, num_digits=8):
    """Serialise an array as space-separated scientific-notation numbers.

    Elements are emitted in row-major order, each formatted with
    ``num_digits`` digits after the decimal point (e.g. ``1.00000000e+00``).
    """
    spec = f'.{num_digits}e'
    return ' '.join(format(value, spec) for value in M.ravel())


def load_torch_image(fname, device, dim=None):
    """Load an image from disk as a resized tensor plus a full-resolution tensor.

    Args:
        fname: path of the image file (read with cv2.imread, i.e. BGR order).
        device: torch device the returned tensors are moved to.
        dim: (width, height) of the resized copy; defaults to the module-level DIM.

    Returns:
        (img_rs, img_raw, scale_w, scale_h): the resized image and the
        full-resolution image as float RGB tensors scaled to [0, 1] (batch
        dimension added by K.image_to_tensor(..., False) — presumably
        (1, 3, H, W), TODO confirm), and the width/height ratios that map
        original pixel coordinates onto the resized image.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``fname``.
    """
    if dim is None:
        dim = DIM
    img_raw = cv2.imread(fname)
    if img_raw is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'cv2.imread could not read image: {fname}')

    # Ratios that map original pixel coordinates onto the resized image.
    scale_w = dim[0] / img_raw.shape[1]
    scale_h = dim[1] / img_raw.shape[0]

    # cv2.resize takes dsize as (width, height).
    img_rs = cv2.resize(img_raw, dim)
    img_rs = K.image_to_tensor(img_rs, False).float() / 255.
    img_rs = K.color.bgr_to_rgb(img_rs)

    # Full-resolution copy (used by the caller for visualisation).
    img_raw = K.image_to_tensor(img_raw, False).float() / 255.
    img_raw = K.color.bgr_to_rgb(img_raw)

    return img_rs.to(device), img_raw.to(device), scale_w, scale_h

In [None]:
import time

plot = True
F_dict = {}

# Pairs with fewer pooled correspondences than this get an all-zero F.
MIN_MATCHES = 8


def _loftr_matches(path_0, path_1):
    """Match an image pair with QuadTree LoFTR.

    Returns (mkpts0, mkpts1, image_raw0, image_raw1): matched keypoints of each
    image rescaled to original-image pixel coordinates, plus the full-resolution
    image tensors (used only for plotting).
    """
    image_0, image_raw0, scale_w0, scale_h0 = load_torch_image(path_0, device)
    image_1, image_raw1, scale_w1, scale_h1 = load_torch_image(path_1, device)
    batch = {"image0": K.color.rgb_to_grayscale(image_0),
             "image1": K.color.rgb_to_grayscale(image_1)}

    with torch.no_grad():
        # The matcher writes its results ('mkpts0_f', 'mkpts1_f', ...) into `batch`.
        model_loftr.matcher(batch)

    mkpts0 = batch['mkpts0_f'].cpu().numpy()
    mkpts1 = batch['mkpts1_f'].cpu().numpy()

    # Undo the resize to DIM. Keep sub-pixel precision: casting to int adds up to
    # 1 px of error, which is large relative to the 0.2 px MAGSAC threshold below.
    mkpts0[:, 0] *= 1 / scale_w0
    mkpts0[:, 1] *= 1 / scale_h0
    mkpts1[:, 0] *= 1 / scale_w1
    mkpts1[:, 1] *= 1 / scale_h1
    return mkpts0, mkpts1, image_raw0, image_raw1


def _superglue_matches(path_0, path_1):
    """Match an image pair with SuperPoint + SuperGlue.

    Returns (sg_mkpts0, sg_mkpts1): coordinates of matched keypoints in the
    images produced by read_image. These are pooled with LoFTR's
    original-resolution points, so they are assumed to be at original
    resolution (resize=[-1]) — TODO confirm.
    """
    _, inp_0, _ = read_image(path_0, device, resize, 0, resize_float)
    _, inp_1, _ = read_image(path_1, device, resize, 0, resize_float)

    # Inference only: no_grad avoids building an autograd graph for every pair.
    with torch.no_grad():
        sg_pred = model_sg({"image0": inp_0, "image1": inp_1})
    sg_pred = {k: v[0].detach().cpu().numpy() for k, v in sg_pred.items()}

    kpts0, kpts1 = sg_pred["keypoints0"], sg_pred["keypoints1"]
    matches = sg_pred["matches0"]
    valid = matches > -1  # -1 marks a keypoint without a match
    return kpts0[valid], kpts1[matches[valid]]


def _estimate_fundamental(mkpts0, mkpts1):
    """Estimate the fundamental matrix with OpenCV USAC_MAGSAC.

    Returns (F, inliers): a 3x3 matrix and a boolean inlier mask, or
    (np.zeros((3, 3)), None) when there are too few matches or OpenCV
    produces no usable model.
    """
    if len(mkpts0) < MIN_MATCHES:
        print(f'Less than {MIN_MATCHES} points')
        return np.zeros((3, 3)), None

    F, inliers = cv2.findFundamentalMat(mkpts0, mkpts1, cv2.USAC_MAGSAC, 0.2, 0.9999, 100000)
    if F is None or inliers is None or F.shape != (3, 3):
        # OpenCV may return None when no model is found; a zero F keeps the
        # submission complete instead of aborting the whole run.
        print('Fundamental matrix estimation failed')
        return np.zeros((3, 3)), None
    return F, inliers > 0


def _plot_matches(mkpts0, mkpts1, image_raw0, image_raw1, inliers):
    """Draw the pooled correspondences with kornia_moons, highlighting inliers."""
    def _to_lafs(kpts):
        n = kpts.shape[0]
        return KF.laf_from_center_scale_ori(torch.from_numpy(kpts).float().view(1, -1, 2),
                                            torch.ones(n).view(1, -1, 1, 1),
                                            torch.ones(n).view(1, -1, 1))

    draw_LAF_matches(
        _to_lafs(mkpts0),
        _to_lafs(mkpts1),
        # Point i of image 0 corresponds to point i of image 1.
        torch.arange(mkpts0.shape[0]).view(-1, 1).repeat(1, 2),
        K.tensor_to_image(image_raw0),
        K.tensor_to_image(image_raw1),
        inliers,
        draw_dict={'inlier_color': (0.2, 1, 0.2),
                   'tentative_color': None,
                   'feature_color': (0.2, 0.5, 1), 'vertical': False})


for i, (sample_id, batch_id, image_0_id, image_1_id) in enumerate(test_samples):
    st = time.time()
    path_0 = f'{src}/test_images/{batch_id}/{image_0_id}.png'
    path_1 = f'{src}/test_images/{batch_id}/{image_1_id}.png'

    lo_kpts0, lo_kpts1, image_raw0, image_raw1 = _loftr_matches(path_0, path_1)
    sg_kpts0, sg_kpts1 = _superglue_matches(path_0, path_1)

    # Pool both matchers' correspondences before robust estimation.
    mkpts0 = np.concatenate([lo_kpts0, sg_kpts0], axis=0)
    mkpts1 = np.concatenate([lo_kpts1, sg_kpts1], axis=0)

    F, inliers = _estimate_fundamental(mkpts0, mkpts1)
    F_dict[sample_id] = F
    if inliers is None:
        continue
    print(f'{inliers.sum()} MATCHES!')

    gc.collect()
    nd = time.time()
    if (i < 3) and plot:
        print("Running time: ", nd - st, " s")
        _plot_matches(mkpts0, mkpts1, image_raw0, image_raw1, inliers)

# One row per sample: id and the row-major flattened 3x3 F.
with open('submission.csv', 'w') as f:
    f.write('sample_id,fundamental_matrix\n')
    for sample_id, F in F_dict.items():
        f.write(f'{sample_id},{FlattenMatrix(F)}\n')