## Thin Plate Spline [1], LoFTR [2] and CRF Post-Processing [3] for image matching and warping

### References
[1] Donato, G., & Belongie, S. J. (2003). Approximation methods for thin plate spline mappings and principal warps. Department of Computer Science and Engineering, University of California, San Diego. - https://github.com/cheind/py-thin-plate-spline

[2] Sun, Jiaming and Shen, Zehong and Wang, Yuang and Bao, Hujun and Zhou, Xiaowei (2021). LoFTR: Detector-Free Local Feature Matching with Transformers - https://github.com/zju3dv/LoFTR

[3] http://web.archive.org/web/20161023180357/http://www.philkr.net/home/densecrf, https://github.com/lucasb-eyer/pydensecrf

## Setup and installation

You should have a `TPS_LoFTR` folder in your Google Drive; this notebook expects it at `/content/drive/MyDrive/SemesterProj/TPS_LoFTR`.




In [None]:
# Mount Google Drive at /content/drive; the project folder is read from there.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Work from the project folder so the relative paths used below (src/, weights/, ../data) resolve.
%cd /content/drive/MyDrive/SemesterProj/TPS_LoFTR

Install the libraries and download the pretrained LoFTR weights for both the indoor (trained on ScanNet) and outdoor (trained on MegaDepth) models.

**IMPORTANT**: The original LoFTR code has been modified for this project, so you no longer need to clone the original LoFTR repository. If you still want to, run the commands below:
```
!git clone https://github.com/zju3dv/LoFTR --depth 1
!mv LoFTR/* . && rm -rf LoFTR
```

In [None]:
# Install libraries: pydensecrf (CRF post-processing) from GitHub, plus torch,
# einops, yacs and kornia for the LoFTR code under src/.
# NOTE(review): `thinplate` (imported below) is not installed here; presumably it
# is a local copy of py-thin-plate-spline inside this folder -- confirm.
!pip install git+https://github.com/lucasb-eyer/pydensecrf.git
!pip install torch einops yacs kornia

In [None]:
# Download pretrained weights for LoFTR
!mkdir weights 
%cd weights/
!gdown --id 1w1Qhea3WLRMS81Vod_k5rxS_GNRgIi-O  # indoor-ds
!gdown --id 1M-VD35-qdB5Iw-AtbDBCKC7hPolFW9UY  # outdoor-ds
%cd ..

Import all the needed libraries.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from PIL import Image
import torch
import cv2
from src.utils.plotting import make_matching_figure
from src.loftr import LoFTR, default_cfg
import thinplate as tps
import os
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral, unary_from_labels
import tqdm

## Inference

### Helper functions

In [None]:
def show_warped(img, warped):
    """
    Show two images side by side in a single 16x8 figure.

    The last axis of each image is reversed before plotting, which turns the
    BGR arrays produced by OpenCV into the RGB order matplotlib expects.
    """
    _, panels = plt.subplots(1, 2, figsize=(16, 8))
    for panel, bgr_image in zip(panels, (img, warped)):
        panel.imshow(bgr_image[..., ::-1], origin='upper')
    plt.show()


def warp_image_cv(img, c_src, c_dst, dshape=None, interp='cubic'):
    """
    Warp an image with a Thin-Plate-Spline fitted to two matched point sets.

    In this notebook it warps satellite labels into the drone view, using the
    LoFTR matches between the satellite and drone images.

    Args:
        img: image to resample (label map or colour label).
        c_src: (N, 2) source control points. The caller normalises them by image
            width/height before passing them in.
        c_dst: (N, 2) destination control points, normalised the same way.
        dshape: shape of the TPS grid; defaults to ``img.shape``.
        interp: 'cubic' for cv2.INTER_CUBIC or 'nn' for cv2.INTER_NEAREST.

    Returns:
        The remapped image returned by ``cv2.remap``.

    Raises:
        ValueError: if ``interp`` is not 'cubic' or 'nn'. Previously an unknown
            method silently returned None after doing all the TPS work.
    """
    # Validate before the (comparatively expensive) TPS fit.
    if interp not in ('cubic', 'nn'):
        raise ValueError(
            "interp must be 'cubic' or 'nn', got {!r}".format(interp))

    dshape = dshape or img.shape

    # Fit TPS parameters from the control-point correspondences
    theta = tps.tps_theta_from_points(c_src, c_dst, reduced=True)

    # Build the sampling grid and turn it into OpenCV remap tables
    grid = tps.tps_grid(theta, c_dst, dshape)
    mapx, mapy = tps.tps_grid_to_remap(grid, img.shape)

    flag = cv2.INTER_CUBIC if interp == 'cubic' else cv2.INTER_NEAREST
    return cv2.remap(img, mapx, mapy, flag)


def threshold_img(img_hsv):
    """
    Convert an HSV colour label into a categorical (class-index) label.

    Used on cubic-interpolated colour labels: each pixel is snapped back to a
    class by its hue. Saturation and value must both be at least 50:

        1 - red,   hue in [0, 5] or [175, 180]
        2 - blue,  hue in [115, 125]
        3 - green, hue in [55, 65]

    Bounds are inclusive on every channel, matching ``cv2.inRange``. Classes are
    assigned in the order red, blue, green, so a later class would win on
    overlap. All remaining pixels are 0.

    Args:
        img_hsv: (H, W, 3) array in OpenCV HSV layout (hue 0-179 for uint8).

    Returns:
        uint8 array of shape (H, W) with values in {0, 1, 2, 3}.
    """
    def in_range(lower, upper):
        # True where every channel lies within [lower, upper] (inclusive)
        return np.all(
            (img_hsv >= np.asarray(lower)) & (img_hsv <= np.asarray(upper)),
            axis=-1,
        )

    # Red wraps around the hue circle, so it is the union of two hue bands.
    # Logical OR replaces the uint8 `+`, which would wrap to 254 (and drop the
    # pixel) if the two bands were ever widened to overlap.
    mask_red = (in_range([0, 50, 50], [5, 255, 255])
                | in_range([175, 50, 50], [180, 255, 255]))
    mask_blue = in_range([115, 50, 50], [125, 255, 255])
    mask_green = in_range([55, 50, 50], [65, 255, 255])

    output = np.zeros(img_hsv.shape[:2], dtype=np.uint8)
    output[mask_red] = 1
    output[mask_blue] = 2
    output[mask_green] = 3
    return output


def draw_matches(mconf, mkpts0, mkpts1, img0_raw, img1_raw, download_pdf=False):
    """
    Draw the matches found by LoFTR between the drone and satellite images.

    Args:
        mconf: per-match confidences, coloured with the jet colormap (alpha 0.7).
        mkpts0: matched keypoints in ``img0_raw``.
        mkpts1: matched keypoints in ``img1_raw``.
        img0_raw: first matched image (the drone image in this notebook).
        img1_raw: second matched image (the satellite image in this notebook).
        download_pdf: if True, also save the figure as "LoFTR-colab-demo.pdf"
            and download it through Colab.
    """
    color = cm.jet(mconf, alpha=0.7)
    text = [
        'LoFTR',
        'Matches: {}'.format(len(mkpts0)),
    ]

    # Draw the figure once for inline display. Previously the non-PDF path drew it twice.
    make_matching_figure(img0_raw, img1_raw, mkpts0, mkpts1, color, mkpts0, mkpts1, text)

    if download_pdf:
        pdf_path = "LoFTR-colab-demo.pdf"
        make_matching_figure(img0_raw, img1_raw, mkpts0, mkpts1, color, mkpts0, mkpts1, text, path=pdf_path)
        # Import locally under a distinct name. The notebook's global `files` is
        # the list of image names (and google.colab.files is never imported),
        # so `files.download` would fail.
        from google.colab import files as colab_files
        colab_files.download(pdf_path)

### Configuration and setup

In [None]:
# Root directory containing all datasets
data_dir = '../data'

# Dataset to process (sub-directory of data_dir)
dataset_dir = '120m'

# Satellite images directory name
img_dir = "img_dir"

# Labels directory name
ann_dir = "ann_dir"

# Drone images directory name
drone_dir = 'drone_dir'

# Maximum number of matches to keep from LoFTR (after the fine-level module)
nr_conf = 1000

# Whether to plot the matched points and the warped results
show = False

# Inference image size as (width, height) - (640, 480) BY DEFAULT IN ORIGINAL LoFTR
inference_im_size = (480, 360)

# Output image size as (width, height)
output_im_size = (5280, 3956)
fullsize = (output_im_size[0] == 5280 and output_im_size[1] == 3956)

# Palette and classes (class index i is drawn with palette[i])
classes = ['background', 'building', 'road', 'water']
palette = [[0, 0, 0], [255, 0, 0], [0, 0, 255], [0, 255, 0]]

# Interpolation method:
#   'nn'    - nearest neighbour
#   'cubic' - cubic interpolation followed by HSV colour thresholding
interp_type = 'cubic'
# warp_image_cv only supports these two methods; fail early instead of mid-loop.
if interp_type not in ('nn', 'cubic'):
    raise ValueError("interp_type must be 'nn' or 'cubic', got {!r}".format(interp_type))

# Dataset paths. data_root must exist before `files` and `output_dir` are
# built from it below (the dataset root was previously used before being defined).
data_root = os.path.join(data_dir, dataset_dir)
dataset_root = data_root  # alias: the same directory under the name used elsewhere
img_root = os.path.join(data_root, img_dir)
ann_root = os.path.join(data_root, ann_dir)
drone_root = os.path.join(data_root, drone_dir)

# Names of the drone images to warp
files = sorted(os.listdir(os.path.join(dataset_root, "human_drone_ann_dir")))

# Directory where the warped labels are saved (created if it does not exist).
# Expand once so the directory that is created is the same one results are saved to.
output_dir = os.path.expanduser(
    os.path.join(dataset_root,
                 'loftr_warped{}_infer_{}_{}_{}_{}p'.format(
                     '_fullsize' if fullsize else '', interp_type,
                     inference_im_size[0], inference_im_size[1], nr_conf)
                 )
)
dir_name = output_dir
os.makedirs(dir_name, mode=0o777, exist_ok=True)

The default config uses dual-softmax for coarse matching, and the outdoor and indoor models share the same config.
Choose the pretrained model with `image_type`, and change config values such as `thr` in the `default_cfg` dictionary.

In [None]:
# Which pretrained LoFTR weights to load: 'indoor' or 'outdoor'
image_type = 'outdoor'

# Confidence threshold of the coarse matching module (LoFTR default: 0.2)
default_cfg['match_coarse']['thr'] = 0.85

# Search-window size on the confidence matrix, passed to the (modified) LoFTR
# config as 'filter_area' = inference width / 4 / filter_window_size.
# The confidence matrix is im_size / 8, and the window is centred on each entry,
# so a window of im_size / 4 spans the whole matrix. filter_window_size = 1 is
# therefore the maximum window, i.e. no filtering.
filter_window_size = 1
default_cfg['match_coarse']['filter_area'] = str(
    int(inference_im_size[0] / 4 / filter_window_size))

# Build the model, load the chosen checkpoint, and move it to the GPU in eval mode
_ckpt_by_type = {
    'indoor': "weights/indoor_ds.ckpt",
    'outdoor': "weights/outdoor_ds.ckpt",
}
matcher = LoFTR(config=default_cfg)
if image_type not in _ckpt_by_type:
    raise ValueError("Wrong image_type is given.")
matcher.load_state_dict(torch.load(_ckpt_by_type[image_type])['state_dict'])
matcher = matcher.eval().cuda()

### Inference and warping

In [None]:
for filename in files:
    # Strip only the final extension so names containing dots still resolve
    img_name = os.path.splitext(filename)[0]

    # Drone image (image0), satellite image (image1) and satellite label paths
    img0_pth = os.path.join(drone_root, '{}.jpg'.format(img_name))
    img1_pth = os.path.join(img_root, '{}.jpg'.format(img_name))
    ann_pth = os.path.join(ann_root, '{}.png'.format(img_name))
    image_pair = [img0_pth, img1_pth]

    # Read as grayscale and resize to the inference resolution
    img0_raw = cv2.imread(image_pair[0], cv2.IMREAD_GRAYSCALE)
    img1_raw = cv2.imread(image_pair[1], cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None (instead of raising) for missing/unreadable files
    for pth, raw in zip(image_pair, (img0_raw, img1_raw)):
        if raw is None:
            raise FileNotFoundError("Could not read image: {}".format(pth))
    img0_raw = cv2.resize(img0_raw, inference_im_size)
    img1_raw = cv2.resize(img1_raw, inference_im_size)

    # Scale to [0, 1] and add batch and channel dimensions: (1, 1, H, W)
    img0 = torch.from_numpy(img0_raw)[None][None].cuda() / 255.
    img1 = torch.from_numpy(img1_raw)[None][None].cuda() / 255.
    batch = {'image0': img0, 'image1': img1}

    # LoFTR writes its predictions into `batch`
    with torch.no_grad():
        matcher(batch)
        mkpts0 = batch['mkpts0_f'].cpu().numpy()
        mkpts1 = batch['mkpts1_f'].cpu().numpy()
        mconf = batch['mconf'].cpu().numpy()

    if show:
        # Plot matches
        draw_matches(mconf, mkpts0, mkpts1, img0_raw, img1_raw)

    # Satellite label at inference resolution: the raw label array for 'nn',
    # a BGR colour image otherwise
    if interp_type == 'nn':
        img = Image.open(ann_pth)
        img = np.array(img.resize(inference_im_size, Image.NEAREST))
    else:
        img = np.array(Image.open(ann_pth).convert('RGB'))[:, :, ::-1]
        img = cv2.resize(img, inference_im_size)

    # Indexes of the nr_conf most confident matches
    indexes = np.argsort(mconf)[::-1][:nr_conf]

    # TPS source = satellite points, destination = drone points, normalised by
    # the label's width/height (fancy indexing returns copies, so mkpts are untouched)
    c_src = mkpts1[indexes, :]
    c_dst = mkpts0[indexes, :]

    c_src[:, 0] /= img.shape[1]
    c_src[:, 1] /= img.shape[0]
    c_dst[:, 0] /= img.shape[1]
    c_dst[:, 1] /= img.shape[0]

    print("There are {} keypoints for {}!".format(len(c_src), img_name))

    # Warp only if there are at least 4 matched points; otherwise the TPS fit
    # is not usable and this image's own label is kept unwarped. Starting from
    # `img` also guarantees no stale `warped` from a previous image is reused.
    if len(c_src) >= 4:
        warped = warp_image_cv(img, c_src, c_dst, dshape=None, interp=interp_type)
    else:
        warped = img
        print("No warping")

    # Upscale to the output resolution with interpolation matching the label type
    if inference_im_size != output_im_size:
        if interp_type == 'nn':
            warped = np.array(Image.fromarray(warped).resize(output_im_size, Image.NEAREST))
        else:
            warped = cv2.resize(warped, output_im_size, interpolation=cv2.INTER_CUBIC)

    if show:
        # Overlay the warped label and the original satellite label on the drone image
        drone_img = cv2.imread(image_pair[0])
        drone_img = cv2.resize(drone_img, output_im_size)
        if interp_type == 'nn':
            label = Image.open(ann_pth)
            label = np.array(label.resize(output_im_size, Image.NEAREST).convert('RGB'))[:, :, ::-1]
            warped_show = Image.fromarray(warped).convert('P')
            warped_show.putpalette(np.array(palette, dtype=np.uint8))
            warped_show = np.array(warped_show.convert('RGB'))[:, :, ::-1]
        else:
            label = np.array(Image.open(ann_pth).convert('RGB'))[:, :, ::-1]
            label = cv2.resize(label, output_im_size, interpolation=cv2.INTER_CUBIC)
            warped_show = warped.copy()
        drone_warped = cv2.addWeighted(drone_img, 0.6, warped_show, 0.7, 0)
        drone_satlabel = cv2.addWeighted(drone_img, 0.6, label, 0.7, 0)
        show_warped(drone_warped, drone_satlabel)

    if interp_type == 'nn':
        seg_img = Image.fromarray(warped).convert('P')
    else:
        # Snap the cubic-interpolated colours back to class indices via HSV thresholds
        warped_hsv = cv2.cvtColor(warped, cv2.COLOR_BGR2HSV)
        thresh_drone = threshold_img(warped_hsv)
        seg_img = Image.fromarray(thresh_drone).convert('P')
    seg_img.putpalette(np.array(palette, dtype=np.uint8))

    # Save result
    seg_img.save(os.path.join(output_dir, filename))

## CRF Post-Processing

More on CRF at:  https://github.com/lucasb-eyer/pydensecrf

### Configuration

In [None]:
# --- Dense-CRF pairwise kernels ---
# Gaussian (smoothness) kernel: a higher compat gives a smoother mask.
compat_gauss, sxy_gauss = 3, 3
# Bilateral (appearance) kernel: sxy is the pixel-proximity scale and srgb the
# colour-difference scale. If srgb is too high, the mask gets cut off.
# (`comapt_bil` keeps its historical spelling because the CRF loop reads it.)
comapt_bil, sxy_bil, srgb_bil = 10, 80, 13

# Number of mean-field inference iterations
iters = 10

# Confidence in the reference labels, strictly between 0 and 1 (chosen manually)
gtprob = 0.7

# Input folders: warped labels and the matching drone images.
# NOTE(review): the warping step writes to a directory named
# 'loftr_warped{_fullsize}_infer_<interp>_<w>_<h>_<n>p', not 'loftr_warped_fullsize';
# confirm this path points at the intended labels.
ann_path = '../data/120m/loftr_warped_fullsize'
drone_path = '../data/120m/drone_dir'

### Post-process

In [None]:
files = sorted(os.listdir(ann_path))

# One CRF label per class, so the CRF agrees with `classes` / `palette`
n_labels = len(classes)

for i, filename in enumerate(files):
    ann_file = os.path.join(ann_path, filename)
    # Skip sub-directories (e.g. .ipynb_checkpoints) that Image.open cannot read
    if not os.path.isfile(ann_file):
        continue

    # Drone image as 3-channel RGB (required by the bilateral kernel) and the
    # reference annotation as an array of class indices
    drone_file = os.path.join(drone_path, os.path.splitext(filename)[0] + '.jpg')
    img = np.array(Image.open(drone_file).convert('RGB'))
    ann = np.array(Image.open(ann_file))
    if img.shape[:2] != ann.shape:
        raise ValueError(
            "Drone image {} has size {} but annotation {} has size {}".format(
                drone_file, img.shape[:2], ann_file, ann.shape))

    # Unary term: every pixel's reference label is trusted with probability gtprob
    U = unary_from_labels(ann, n_labels, gtprob, zero_unsure=False)
    d = dcrf.DenseCRF2D(ann.shape[1], ann.shape[0], n_labels)  # (width, height, labels)
    d.setUnaryEnergy(U)
    d.addPairwiseGaussian(sxy=(sxy_gauss, sxy_gauss), compat=compat_gauss, kernel=dcrf.DIAG_KERNEL, normalization=dcrf.NORMALIZE_SYMMETRIC)
    d.addPairwiseBilateral(sxy=(sxy_bil, sxy_bil), srgb=(srgb_bil, srgb_bil, srgb_bil), rgbim=img, compat=comapt_bil, kernel=dcrf.DIAG_KERNEL, normalization=dcrf.NORMALIZE_SYMMETRIC)
    Q = d.inference(iters)

    # Most probable label per pixel, reshaped back to image layout
    # (named label_map to avoid shadowing the builtin `map`)
    label_map = np.argmax(Q, axis=0).reshape((ann.shape[0], ann.shape[1]))

    # Convert to 'P' format and save (filename encodes the CRF settings)
    seg_img = Image.fromarray(label_map.astype(np.uint8)).convert('P')
    seg_img.putpalette(np.array(palette, dtype=np.uint8))
    seg_img.save('test{}_after{}_prob{}_compat_gauss{}_sxy_gauss{}_sxy_bil{}_srgb_bil{}.png'.format(i, str(iters), str(gtprob).split('.')[1], str(compat_gauss), str(sxy_gauss), str(sxy_bil), str(srgb_bil)))
    print("{} done!".format(filename))
    