# XFeat matching example (sparse and semi-dense)

## First, clone repository

In [None]:
# Clone the XFeat repository (only needed once, e.g. on Colab):
# !cd /content && git clone 'https://github.com/verlab/accelerated_features.git'
# Switch the kernel's working directory to the local checkout.
# NOTE(review): hardcoded absolute path for one specific machine — adjust for your environment.
# Later cells use paths relative to this directory (e.g. '..' on sys.path, '../frames/...').
%cd /Users/vaidehi.som/Documents/accelerated_features
%pwd

## Initialize XFeat

In [None]:
import numpy as np
import os
import sys
sys.path.append(os.path.abspath('..'))  # Add parent directory to Python path
import torch
import torch.nn as nn
import tqdm
import cv2
import matplotlib.pyplot as plt
import imageio as imio
import time
import onnxruntime as ort
from PIL import Image

from modules.xfeat_match_star import XFeat

# Build the XFeat* model and put it in inference mode.
xfeat_model = XFeat()
xfeat_model.eval()


def _read_reversed_channels(path):
    """Read ``path`` with imageio and reverse the channel axis.

    For a 3-channel image this turns imageio's RGB into BGR (OpenCV order);
    ``np.copy`` materialises the negative-stride view as a regular array.
    """
    return np.copy(imio.v2.imread(path)[..., ::-1])


# Example frames used throughout the notebook.
im1 = _read_reversed_channels('../frames/vaidehi_paraglider_multi_object.png')
im2 = _read_reversed_channels('../frames/vaidehi_paraglider_multi_object2.png')
print(im1.shape)

In [3]:
def prepare_detect_input(im, channel_order="rgb"):
    """Convert an image into the (1, 1, H, W) float32 grayscale tensor fed to the detector.

    Parameters
    ----------
    im : array-like
        Either an (H, W) single-channel image, or an (H, W, C) image with
        C >= 3 whose first three channels are colour channels stored in
        ``channel_order``. Any further channels (e.g. alpha) are ignored.
    channel_order : {"rgb", "bgr"}, default "rgb"
        Order of the first three colour channels. NOTE: the frames in this
        notebook are BGR (the imageio loader reverses the channel axis and
        ``cv2.imread`` returns BGR), so pass ``"bgr"`` to apply the luminance
        weights to the right channels. The default keeps the historical
        behaviour.

    Returns
    -------
    torch.Tensor
        float32 tensor of shape (1, 1, H, W). Intensities keep the input's
        value range (no normalisation is applied).

    Raises
    ------
    ValueError
        If ``channel_order`` is not "rgb"/"bgr", or ``im`` is neither 2-D nor
        3-D with at least three channels.
    """
    order = channel_order.lower()
    if order not in ("rgb", "bgr"):
        raise ValueError(f"channel_order must be 'rgb' or 'bgr', got {channel_order!r}")

    # np.array (not asarray) always copies, so the returned tensor never
    # shares memory with the caller's image.
    im_np = np.array(im, dtype=np.float32)

    if im_np.ndim == 2:
        # Already single-channel.
        im_gray = im_np
    elif im_np.ndim == 3 and im_np.shape[2] >= 3:
        if order == "rgb":
            red, green, blue = im_np[:, :, 0], im_np[:, :, 1], im_np[:, :, 2]
        else:
            blue, green, red = im_np[:, :, 0], im_np[:, :, 1], im_np[:, :, 2]
        # ITU-R BT.601 luminosity weights.
        im_gray = 0.299 * red + 0.587 * green + 0.114 * blue
    else:
        raise ValueError(
            f"expected an (H, W) or (H, W, C>=3) image, got shape {im_np.shape}"
        )

    # Add the batch and channel dimensions: (H, W) -> (1, 1, H, W).
    return torch.from_numpy(im_gray).unsqueeze(0).unsqueeze(0)

def convert_img(im):
    """Swap the first and third colour channels of ``im`` using OpenCV.

    Applies ``cv2.COLOR_RGB2BGR``. The frames in this notebook are BGR, so
    the result is RGB, which is the order ``plt.imshow`` expects.
    """
    as_array = np.array(im)
    swapped = cv2.cvtColor(as_array, cv2.COLOR_RGB2BGR)
    return swapped

## Masking

In [64]:
# Mask input images according to given bounding boxes
def mask_image(image, bboxes):
    """Keep only the pixels inside the given bounding boxes and zero everything else.

    Parameters
    ----------
    image : torch.Tensor or np.ndarray
        Image whose last two axes are (H, W), e.g. (C, H, W) or the
        (1, 1, H, W) tensor returned by ``prepare_detect_input``.
    bboxes : iterable of boxes
        Each box is a sequence of ``[x, y]`` corner points. The notebook uses
        four per box: top-left, top-right, bottom-left, bottom-right. Only the
        min/max of the x and y coordinates are used, so point order does not
        matter.

    Returns
    -------
    Same type as ``image``
        A new tensor/array; ``image`` itself is not modified. The axis-aligned
        rectangle spanned by each box is copied over, *including* its max-x
        column and max-y row. Coordinates are clipped at 0, so boxes that
        extend past the top/left edge do not wrap around via negative indexing.
    """
    # Start from an all-zero image with the same shape/dtype (and device for
    # torch), then copy the boxed regions into it.
    if isinstance(image, torch.Tensor):
        masked_image = torch.zeros_like(image)
    else:
        masked_image = np.zeros_like(image)

    for box in bboxes:
        xs = [point[0] for point in box]
        ys = [point[1] for point in box]

        # Clamp every bound at 0: a negative slice bound would index from the end.
        x_min = max(int(np.floor(min(xs))), 0)
        y_min = max(int(np.floor(min(ys))), 0)
        # +1 makes the far edge inclusive. Overshooting W/H is harmless, because
        # slicing clips it.
        x_max = max(int(np.ceil(max(xs))) + 1, 0)
        y_max = max(int(np.ceil(max(ys))) + 1, 0)

        masked_image[..., y_min:y_max, x_min:x_max] = image[..., y_min:y_max, x_min:x_max]

    return masked_image
    

def _corner_box(x0, y0, x1, y1):
    """Return ``[top-left, top-right, bottom-left, bottom-right]`` corners of an axis-aligned box."""
    return [[x0, y0], [x1, y0], [x0, y1], [x1, y1]]


# Per-frame bounding boxes (pixel x, y) in the corner format mask_image expects.
im1_box = [
    _corner_box(2516, 1063, 2636, 1179),
    _corner_box(2831, 1002, 2877, 1044),
    _corner_box(2903, 725, 2918, 733),
]
im2_box = [
    _corner_box(2540, 1083, 2660, 1200),
    _corner_box(2853, 1020, 2897, 1064),
    _corner_box(2937, 734, 2950, 742),
]


# Detect

### Session prep: export the detector to ONNX

In [None]:
# Example (1, 1, H, W) grayscale tensor used only to trace the detector for the
# ONNX export below; H and W are declared dynamic in that export.
im_gray = prepare_detect_input(im1)

# torch.onnx.export traces a module's forward(); this wrapper routes forward()
# to XFeat's `xfeat_detect` so that method is what gets exported.
class DetectXFeatWrapper(torch.nn.Module):
    """Expose ``model.xfeat_detect`` as this module's ``forward``."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    @torch.inference_mode()
    def forward(self, img):
        """Run the wrapped model's ``xfeat_detect`` on ``img`` with autograd disabled."""
        detect = self.model.xfeat_detect
        return detect(img)

detect_xfeat_wrapper = DetectXFeatWrapper(xfeat_model)

# Export options for the detector graph. Only the image height/width are
# dynamic. `output_names` names only the first output; later cells read
# the session outputs by position.
_detect_export_options = dict(
    export_params=True,
    opset_version=11,  # ONNX opset
    training=torch.onnx.TrainingMode.EVAL,
    do_constant_folding=True,  # pre-compute constant sub-graphs at export time
    input_names=['img'],
    output_names=['out'],
    dynamic_axes={'img': {2: 'height', 3: 'width'}},
)

torch.onnx.export(
    detect_xfeat_wrapper,
    im_gray,
    "/Users/vaidehi.som/github/xfeat_star_detect.onnx",
    **_detect_export_options,
)

### Use session

In [4]:
# Load the exported detector graph into an ONNX Runtime session (default providers).
_detect_onnx_path = "/Users/vaidehi.som/github/xfeat_star_detect.onnx"
session_detect = ort.InferenceSession(_detect_onnx_path)

In [None]:

# Run the ONNX detector on the first frame.
im1_gray = prepare_detect_input(im1)
im1_gray_np = im1_gray.cpu().detach().numpy()

inputs = {'img': im1_gray_np}
output1 = session_detect.run(None, inputs)

# Shapes of the first three outputs (used later as keypoints, descriptors, scales).
for _k in range(3):
    print(output1[_k].shape)

# Drop the batch axis and wrap every (x, y) row in an OpenCV keypoint.
ref_points = output1[0].squeeze(0)
keypoints1 = [cv2.KeyPoint(x=pt[0], y=pt[1], size=5) for pt in ref_points]

# convert_img flips the channel order so matplotlib shows true colours.
im1_cv = convert_img(im1)
im1_keypoints = cv2.drawKeypoints(
    im1_cv,
    keypoints1,
    None,
    color=(255, 0, 0),
    flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS,
)

_fig, _ax = plt.subplots(figsize=(10, 10))
_ax.imshow(im1_keypoints)
_ax.axis('off')
plt.show()


In [None]:
# Run the ONNX detector on the second frame.
im2_gray = prepare_detect_input(im2)
im2_gray_np = im2_gray.cpu().detach().numpy()

inputs = {'img': im2_gray_np}
output2 = session_detect.run(None, inputs)

# Shapes of the first three outputs (used later as keypoints, descriptors, scales).
for _k in range(3):
    print(output2[_k].shape)

# Drop the batch axis and wrap every (x, y) row in an OpenCV keypoint.
dst_points = output2[0].squeeze(0)
keypoints2 = [cv2.KeyPoint(x=pt[0], y=pt[1], size=5) for pt in dst_points]

# convert_img flips the channel order so matplotlib shows true colours.
im2_cv = convert_img(im2)
im2_keypoints = cv2.drawKeypoints(
    im2_cv,
    keypoints2,
    None,
    color=(255, 0, 0),
    flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS,
)

_fig, _ax = plt.subplots(figsize=(10, 10))
_ax.imshow(im2_keypoints)
_ax.axis('off')
plt.show()


## Matching - Semi-dense setting

### Prepare the matcher inputs (used for the ONNX export and for inference)

In [None]:

def _one_based_ids(num_points):
    """Return the ids 1..num_points as a float32 array of shape (1, num_points, 1)."""
    return np.arange(1, num_points + 1).astype(np.float32)[np.newaxis, :, np.newaxis]


# Reference frame: ids, keypoints, descriptors, and scales with a trailing unit axis.
dec1_ids_np = _one_based_ids(output1[0].shape[1])
dec1_kps_np = output1[0].astype(np.float32)
dec1_desc_np = output1[1].astype(np.float32)
dec1_sc_np = output1[2].astype(np.float32)[..., np.newaxis]

# Target frame: ids, keypoints, descriptors (the matcher takes no target scales).
dec2_ids_np = _one_based_ids(output2[0].shape[1])
dec2_kps_np = output2[0].astype(np.float32)
dec2_desc_np = output2[1].astype(np.float32)

for _arr in (dec1_ids_np, dec1_kps_np, dec1_desc_np, dec1_sc_np,
             dec2_ids_np, dec2_kps_np, dec2_desc_np):
    print(_arr.shape)

### Session prep: export the matcher to ONNX

In [8]:
# Torch views (torch.from_numpy shares memory) of the matcher inputs; they are
# only used as example inputs for tracing the ONNX export.
(dec1_ids, dec1_kps, dec1_desc, dec1_sc,
 dec2_ids, dec2_kps, dec2_desc) = [
    torch.from_numpy(arr)
    for arr in (dec1_ids_np, dec1_kps_np, dec1_desc_np, dec1_sc_np,
                dec2_ids_np, dec2_kps_np, dec2_desc_np)
]


# torch.onnx.export traces a module's forward(); this wrapper routes forward()
# to XFeat's `match_xfeat_star_onnx` so the matcher is what gets exported.
class MatchXFeatWrapper(torch.nn.Module):
    """Expose ``model.match_xfeat_star_onnx`` as this module's ``forward``."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    @torch.inference_mode()
    def forward(self, dec1_ids, dec1_kps, dec1_desc, dec1_scales, dec2_ids, dec2_kps, dec2_desc):
        """Pass the reference inputs (ids, keypoints, descriptors, scales), then the
        target inputs (ids, keypoints, descriptors), positionally to the matcher."""
        reference = (dec1_ids, dec1_kps, dec1_desc, dec1_scales)
        target = (dec2_ids, dec2_kps, dec2_desc)
        return self.model.match_xfeat_star_onnx(*reference, *target)

match_xfeat_wrapper = MatchXFeatWrapper(xfeat_model)

# Graph input names (in forward() argument order) and output names.
_match_input_names = ['dec1_ids', 'dec1_kps', 'dec1_desc', 'dec1_sc',
                      'dec2_ids', 'dec2_kps', 'dec2_desc']
_match_output_names = ['ref_idx', 'ref_pnts', 'target_idx', 'target_pnts']

# Axis 1 of every input is dynamic: 'kps' for the reference frame and 'kps2'
# for the target frame. Axis 0 of every output is dynamic ('kps3').
_match_dynamic_axes = {name: {1: 'kps'} for name in _match_input_names[:4]}
_match_dynamic_axes.update({name: {1: 'kps2'} for name in _match_input_names[4:]})
_match_dynamic_axes.update({name: {0: 'kps3'} for name in _match_output_names})

torch.onnx.export(
    match_xfeat_wrapper,
    (dec1_ids, dec1_kps, dec1_desc, dec1_sc, dec2_ids, dec2_kps, dec2_desc),
    "/Users/vaidehi.som/github/xfeat_star_match.onnx",
    export_params=True,
    opset_version=18,  # ONNX opset
    do_constant_folding=True,  # pre-compute constant sub-graphs at export time
    input_names=_match_input_names,
    output_names=_match_output_names,
    dynamic_axes=_match_dynamic_axes,
)

### Use session

In [9]:
# Load the exported matcher graph into an ONNX Runtime session (default providers).
_match_onnx_path = "/Users/vaidehi.som/github/xfeat_star_match.onnx"
session_match = ort.InferenceSession(_match_onnx_path)

In [10]:
def _draw_warped_outline(canvas, src_shape, H):
    """Draw in green on ``canvas`` the border of a ``src_shape`` (h, w, ...) image warped by ``H``."""
    h, w = src_shape[:2]
    corners = np.array(
        [[0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1]], dtype=np.float32
    ).reshape(-1, 1, 2)
    warped = cv2.perspectiveTransform(corners, H).reshape(-1, 2)
    # Pair each corner with its predecessor; index -1 wraps to the last corner,
    # which closes the quadrilateral.
    for i in range(len(warped)):
        start_point = tuple(int(v) for v in warped[i - 1])
        end_point = tuple(int(v) for v in warped[i])
        cv2.line(canvas, start_point, end_point, (0, 255, 0), 4)


def _hconcat_padded(left, right):
    """Concatenate two images side by side, zero-padding the shorter one at the bottom."""
    height = max(left.shape[0], right.shape[0])

    def pad(img):
        missing = height - img.shape[0]
        if missing == 0:
            return img
        pad_width = [(0, missing)] + [(0, 0)] * (img.ndim - 1)
        return np.pad(img, pad_width)

    return np.concatenate((pad(left), pad(right)), axis=1)


def warp_corners_and_draw_matches(ref_points, dst_points, img1, img2):
    """Estimate a homography between matched points and visualise the result.

    Parameters
    ----------
    ref_points, dst_points : array-like of shape (N, 2)
        Corresponding (x, y) pixel coordinates in ``img1`` and ``img2``.
        Row ``i`` of one is matched with row ``i`` of the other.
    img1, img2 : np.ndarray
        The two images. The function draws on copies, so the inputs are not
        modified.

    Returns
    -------
    np.ndarray
        ``img1`` and ``img2`` side by side; the shorter image is zero-padded
        at the bottom. Each RANSAC inlier match is drawn as a filled circle of
        one random colour in both halves, and the outline of ``img1`` warped by
        the estimated homography is drawn in green on the ``img2`` half. If no
        homography can be estimated (fewer than 4 matches, or estimation
        fails), the images are returned side by side without annotations.

    Raises
    ------
    ValueError
        If ``ref_points`` and ``dst_points`` contain different numbers of points.
    """
    ref_points = np.asarray(ref_points, dtype=np.float32).reshape(-1, 2)
    dst_points = np.asarray(dst_points, dtype=np.float32).reshape(-1, 2)
    if len(ref_points) != len(dst_points):
        raise ValueError(
            "ref_points and dst_points must have the same length, "
            f"got {len(ref_points)} and {len(dst_points)}"
        )

    img1_with_matches = img1.copy()
    img2_with_corners = img2.copy()

    H = mask = None
    # A homography needs at least 4 correspondences.
    if len(ref_points) >= 4:
        H, mask = cv2.findHomography(
            ref_points, dst_points, cv2.USAC_MAGSAC, 3.5, maxIters=1000, confidence=0.999
        )

    # When no model is found, findHomography returns an empty result (None in
    # Python). Calling mask.flatten() on it would crash.
    if H is not None and mask is not None:
        inliers = mask.ravel() > 0
        _draw_warped_outline(img2_with_corners, img1.shape, H)

        # Draw each inlier pair with the same random colour in both images.
        for (x1, y1), (x2, y2) in zip(ref_points[inliers], dst_points[inliers]):
            color = tuple(np.random.randint(0, 255, 3).tolist())
            cv2.circle(img1_with_matches, (int(x1), int(y1)), 8, color, -1)
            cv2.circle(img2_with_corners, (int(x2), int(y2)), 8, color, -1)

    return _hconcat_padded(img1_with_matches, img2_with_corners)

In [None]:
# Run the exported matcher on the features of both frames.
inputs = {
    'dec1_ids': dec1_ids_np,
    'dec1_kps': dec1_kps_np,
    'dec1_desc': dec1_desc_np,
    'dec1_sc': dec1_sc_np,
    'dec2_ids': dec2_ids_np,
    'dec2_kps': dec2_kps_np,
    'dec2_desc': dec2_desc_np,
}
outputs = session_match.run(None, inputs)

# Outputs come back in the export's output_names order:
# ref_idx, ref_pnts, target_idx, target_pnts.
for _out in outputs:
    print(_out.shape)

mkpts_0 = outputs[1]  # ref_pnts
mkpts_1 = outputs[3]  # target_pnts

idx = outputs[0].squeeze(1)
target_idx = outputs[2].squeeze(1)
# Spot-check one matched index pair. The guard avoids an IndexError when
# there are fewer than two matches.
if len(idx) > 1 and len(target_idx) > 1:
    print(idx[1])
    print(target_idx[1])

canvas = warp_corners_and_draw_matches(mkpts_0, mkpts_1, im1, im2)

# The canvas is built from the BGR frames, so reverse the channels for matplotlib.
# Separate statements keep the cell from echoing an `(AxesImage, None)` tuple.
plt.figure(figsize=(12, 12))
plt.imshow(canvas[..., ::-1])
plt.show()

## Fisheye (experimental — can be skipped for now)

In [None]:
def _imread_or_raise(path):
    """Read ``path`` with cv2.imread (BGR), raising instead of returning None on failure."""
    image = cv2.imread(path)
    if image is None:
        # cv2.imread signals a missing or unreadable file by returning None.
        raise FileNotFoundError(f"cv2.imread could not read {path!r}")
    return image


# NOTE(review): hardcoded absolute paths for one specific machine — adjust for your environment.
im1 = _imread_or_raise('/Users/vaidehi.som/Documents/accelerated_features/frames/fisheye/current_frame_10.jpg')
im2 = _imread_or_raise('/Users/vaidehi.som/Documents/accelerated_features/frames/fisheye/ref_frame_10.jpg')
print(im1.shape)

In [None]:
# Undistort a fisheye image
def undistort(image, K=None, D=None, dim=(2432, 2160)):
    """Remove fisheye distortion from ``image`` using OpenCV's fisheye camera model.

    Parameters
    ----------
    image : np.ndarray
        Image in a layout ``cv2.remap`` accepts: (H, W) or (H, W, C).
    K : array-like of shape (3, 3), optional
        Camera matrix ``[[fx, 0, cx], [0, fy, cy], [0, 0, 1]]``. If omitted,
        it is built from the notebook globals ``fx``, ``fy``, ``cx``, ``cy``.
    D : array-like of 4 floats, optional
        Fisheye distortion coefficients. If omitted, they are built from the
        notebook globals ``k1``, ``k2``, ``k3``, ``k4``.
    dim : (width, height), default (2432, 2160)
        Size of the undistorted output image.

    Returns
    -------
    np.ndarray
        The undistorted image, ``dim`` in size. Pixels outside the source
        image are filled with zeros (``BORDER_CONSTANT``).

    Raises
    ------
    ValueError
        If K or D is omitted and the corresponding globals are not defined.
    """
    # The calibration values are camera-specific, normally obtained by calibration.
    try:
        if K is None:
            K = [[fx, 0, cx],
                 [0, fy, cy],
                 [0, 0, 1]]
        if D is None:
            D = [k1, k2, k3, k4]
    except NameError as exc:
        raise ValueError(
            "undistort() needs calibration data: pass K= and D=, or define "
            "fx, fy, cx, cy and k1, k2, k3, k4 first"
        ) from exc

    # cv2.fisheye requires floating-point K/D; integer inputs fail its depth assertion.
    K = np.asarray(K, dtype=np.float64)
    D = np.asarray(D, dtype=np.float64)

    # Build the undistortion lookup maps (identity rectification, same K for
    # the new camera) and resample the image through them.
    map1, map2 = cv2.fisheye.initUndistortRectifyMap(K, D, np.eye(3), K, tuple(dim), cv2.CV_16SC2)
    undistorted_img = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
    return undistorted_img


In [None]:
# Undistort the first fisheye frame *before* building the network input:
# cv2.remap needs an image-shaped array, not the (1, 1, H, W) tensor layout
# that prepare_detect_input produces.
im1_undistorted = undistort(im1)
im1_gray = prepare_detect_input(im1_undistorted)
im1_gray_np = im1_gray.cpu().detach().numpy()

inputs = {'img': im1_gray_np}
output1 = session_detect.run(None, inputs)

print(output1[0].shape)
print(output1[1].shape)
print(output1[2].shape)

# Drop the batch axis and wrap every (x, y) row in an OpenCV keypoint.
ref_points = output1[0].squeeze(0)
keypoints1 = [cv2.KeyPoint(x=p[0], y=p[1], size=5) for p in ref_points]

# Draw on the undistorted image, since the keypoints are in its coordinates.
im1_cv = convert_img(im1_undistorted)
im1_keypoints = cv2.drawKeypoints(im1_cv, keypoints1, None, color=(255, 0, 0), flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)

plt.figure(figsize=(10, 10))
plt.imshow(im1_keypoints)
plt.axis('off')
plt.show()

In [None]:
# Run the ONNX detector on the second fisheye frame.
# NOTE(review): unlike the first fisheye frame, this one is not passed through
# undistort() — confirm whether that is intentional before matching the two.
im2_gray = prepare_detect_input(im2)
im2_gray_np = im2_gray.cpu().detach().numpy()

inputs = {'img': im2_gray_np}
output2 = session_detect.run(None, inputs)

for _k in range(3):
    print(output2[_k].shape)

# Drop the batch axis and wrap every (x, y) row in an OpenCV keypoint.
dst_points = output2[0].squeeze(0)
keypoints2 = [cv2.KeyPoint(x=pt[0], y=pt[1], size=5) for pt in dst_points]

# convert_img flips the channel order so matplotlib shows true colours.
im2_cv = convert_img(im2)
im2_keypoints = cv2.drawKeypoints(
    im2_cv,
    keypoints2,
    None,
    color=(255, 0, 0),
    flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS,
)

_fig, _ax = plt.subplots(figsize=(10, 10))
_ax.imshow(im2_keypoints)
_ax.axis('off')
plt.show()