In [1]:
import matplotlib.pyplot as plt
from pathlib import Path
import scipy.io as spio
from tqdm import tqdm
from PIL import Image
import numpy as np
import cv2
import os
from utils import ImageNode


%load_ext autoreload
%autoreload 2

# 1. Load images

In [2]:
# Percentile handed to max_descriptor_distance_threshold() when the
# descriptor-distance cut-off is derived from consecutive frames.
DISTANCE_THRESHOLD_PERCENTILE = 75
# Reprojection threshold passed to cv2.findHomography(..., cv2.RANSAC, ...).
RANSAC_THRESHOLD = 5
# Factor passed to resize_keypoint_and_image() for every frame and the reference.
SCALE_FACTOR = 0.5

In [3]:
def _frame_number(fpath):
    """Return the integer found after the first underscore of the file stem."""
    return int(fpath.stem.split('_')[1])


# Every input is looked up under ./volley relative to the current working directory.
data_path = Path.cwd() / 'volley'
input_dir = data_path / 'input'
reference_dir = data_path / 'reference'

# Keypoint files and frames are both ordered numerically by their frame number.
kps_fpaths = sorted(input_dir.glob('kp*.mat'), key=_frame_number)
imgs_fpaths = sorted(input_dir.glob('*.jpg'), key=_frame_number)

ref_kps_fpath = reference_dir / 'kp_ref.mat'
ref_img_fpath = reference_dir / 'img_ref.jpg'

In [4]:
# Read each per-frame .mat file once and pull both arrays out of the same dict
# (previously every file was parsed twice, once for 'kp' and once for 'desc').
keypoints, descriptors = [], []
for fpath in kps_fpaths:
    frame_mat = spio.loadmat(fpath)
    keypoints.append(frame_mat['kp'])
    descriptors.append(frame_mat['desc'])

# Same for the reference keypoint file.
ref_mat = spio.loadmat(ref_kps_fpath)
ref_keypoints = ref_mat['kp']
ref_descriptors = ref_mat['desc']

# Frames and reference image as numpy arrays, in the same order as the keypoint lists.
images = [np.array(Image.open(fpath)) for fpath in imgs_fpaths]
ref_img = np.array(Image.open(ref_img_fpath))

In [5]:
from utils import resize_keypoint_and_image


# Pass every frame and its keypoints through resize_keypoint_and_image with
# SCALE_FACTOR, overwriting the list entries in place.
# NOTE(review): this cell consumes its own output, so running it twice applies the
# resize twice; re-run the loading cell first if it has to be repeated.
for idx, frame_kps in enumerate(keypoints):
    keypoints[idx], images[idx] = resize_keypoint_and_image(frame_kps, images[idx], SCALE_FACTOR)

# The reference image and its keypoints get the same treatment.
ref_keypoints, ref_img = resize_keypoint_and_image(ref_keypoints, ref_img, SCALE_FACTOR)


In [6]:
# Bounds of the band applied to keypoint column 1 (the column the plots below use
# as the y coordinate). Only keypoints strictly inside the band are kept.
# NOTE(review): presumably expressed in pixels of the resized images - confirm.
KEYPOINT_Y_MIN = 18
KEYPOINT_Y_MAX = 140


def _inside_y_band(kps, y_min=KEYPOINT_Y_MIN, y_max=KEYPOINT_Y_MAX):
    """Return a boolean mask of the rows whose column-1 value lies strictly between the bounds."""
    return (kps[:, 1] > y_min) & (kps[:, 1] < y_max)


# Per-frame keypoints and descriptors are filtered with the same mask so their rows stay aligned.
for i in range(len(keypoints)):
    mask = _inside_y_band(keypoints[i])
    keypoints[i] = keypoints[i][mask]
    descriptors[i] = descriptors[i][mask]

# The reference keeps its unfiltered arrays; the filtered versions get their own names.
mask = _inside_y_band(ref_keypoints)
filtered_ref_keypoints = ref_keypoints[mask]
filtered_ref_descriptors = ref_descriptors[mask]


In [7]:
# One ImageNode per frame; idx is the frame's position in the input lists.
image_nodes = [
    ImageNode(*frame_data, idx)
    for idx, frame_data in enumerate(zip(images, keypoints, descriptors))
]
# The reference node uses the band-filtered reference arrays and idx=-1.
image_node_ref = ImageNode(ref_img, filtered_ref_keypoints, filtered_ref_descriptors, idx=-1)

In [None]:
# Overlay the keypoints kept for the first frame on top of that frame.
first_node = image_nodes[0]
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(first_node.image, cmap='gray')
ax.scatter(first_node.keypoints[:, 0], first_node.keypoints[:, 1], c='r', s=10)
ax.set_title("Filtered Keypoints on first Image")
ax.axis('off')
plt.show()

In [None]:
# Overlay the band-filtered reference keypoints on the reference image.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(ref_img, cmap='gray')
ax.scatter(filtered_ref_keypoints[:, 0], filtered_ref_keypoints[:, 1], c='r', s=10)
ax.set_title("Filtered Keypoints on Reference Image")
ax.axis('off')
plt.show()

# 2. Compute thresholds

In [10]:
def find_matches(src_node, dst_node, max_desc_dist=None):
    """Match the descriptors of two nodes with a cross-checked brute-force L2 matcher.

    Args:
        src_node: object exposing `.descriptors` and `.keypoints` (query side).
        dst_node: object exposing `.descriptors` and `.keypoints` (train side).
        max_desc_dist: when not None, matches with a larger descriptor distance
            are discarded.

    Returns:
        tuple: `(matches, src_pts, dst_pts)` where `matches` are the retained
        matches and `src_pts` / `dst_pts` are float32 arrays holding, row for row,
        the keypoint of each match in the source and destination node. With no
        retained matches the arrays have zero rows but keep the keypoint columns.
    """
    bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
    matches = bf.match(src_node.descriptors, dst_node.descriptors)

    if max_desc_dist is not None:
        matches = [m for m in matches if m.distance <= max_desc_dist]

    # Explicit integer index arrays keep the (0, k) shape when nothing matched.
    query_rows = np.asarray([m.queryIdx for m in matches], dtype=np.intp)
    train_rows = np.asarray([m.trainIdx for m in matches], dtype=np.intp)

    src_pts = np.asarray(src_node.keypoints)[query_rows].astype(np.float32)
    dst_pts = np.asarray(dst_node.keypoints)[train_rows].astype(np.float32)

    return matches, src_pts, dst_pts

In [11]:
def max_descriptor_distance_threshold(image_nodes, distance_threshold_percentile=50):
    """Derive a descriptor-distance cut-off from consecutive image nodes.

    Each node is matched against its successor (cross-checked brute-force L2).
    For every pair the requested percentile of the match distances is taken, and
    the largest of these per-pair values is returned.

    Args:
        image_nodes: sequence of objects exposing `.descriptors`.
        distance_threshold_percentile: percentile (0-100) taken per pair.

    Returns:
        The largest per-pair percentile value.

    Raises:
        ValueError: if no consecutive pair produced any match (including the
            case of fewer than two nodes).
    """
    pair_matcher = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
    pair_thresholds = []

    for i in tqdm(range(len(image_nodes) - 1)):
        matches = pair_matcher.match(image_nodes[i].descriptors, image_nodes[i + 1].descriptors)
        if len(matches) == 0:
            # A percentile over no distances is NaN, and a NaN makes the final
            # max() depend on the order of the pairs - leave such pairs out.
            continue
        distances = np.array([m.distance for m in matches])
        pair_thresholds.append(np.percentile(distances, distance_threshold_percentile))

    if not pair_thresholds:
        raise ValueError("No consecutive image pair produced any descriptor match.")

    return max(pair_thresholds)

In [12]:
def compute_min_inliers_threshold(image_nodes, max_desc_dist):
    """Return the smallest RANSAC inlier count observed between consecutive nodes.

    For each node and its successor the filtered matches are fed to
    `cv2.findHomography` (RANSAC, `RANSAC_THRESHOLD`) and the inliers flagged in
    the returned mask are counted.

    Pairs for which no homography can be estimated (fewer than four matches, or
    OpenCV returning no mask) are skipped with a warning instead of aborting the
    whole computation.

    Args:
        image_nodes: sequence of image nodes accepted by `find_matches`.
        max_desc_dist: descriptor-distance cut-off forwarded to `find_matches`.

    Returns:
        int: minimum inlier count over the pairs that yielded a homography.

    Raises:
        ValueError: if no consecutive pair yielded a homography (including the
            case of fewer than two nodes).
    """
    import warnings

    inliers_per_pair = []

    for i in range(len(image_nodes) - 1):
        _, src_pts, dst_pts = find_matches(image_nodes[i], image_nodes[i + 1], max_desc_dist)

        mask = None
        # findHomography needs at least four correspondences.
        if len(src_pts) >= 4:
            _, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, RANSAC_THRESHOLD)

        if mask is None:
            warnings.warn(
                f"No homography between consecutive images {i} and {i + 1}; pair skipped."
            )
            continue

        inliers_per_pair.append(int(np.sum(mask)))

    if not inliers_per_pair:
        raise ValueError("No consecutive image pair yielded a homography.")

    return min(inliers_per_pair)

In [None]:
# Descriptor-distance cut-off derived from consecutive frames at the configured
# percentile; it is reused by every matching step below.
max_desc_dist = max_descriptor_distance_threshold(image_nodes, DISTANCE_THRESHOLD_PERCENTILE)
print('The maximum descriptor distance is:', round(max_desc_dist, 2))

In [None]:
# Smallest inlier count seen between consecutive frames; used below as the
# minimum number of inliers a homography needs to be accepted.
min_inliers_threshold = compute_min_inliers_threshold(image_nodes, max_desc_dist)
print('The minimum number of inliers is:', min_inliers_threshold)

# 3. Compute homographies

In [15]:
def find_valid_homography(src_node, dst_node, min_inliers_threshold, max_desc_dist):
    """Estimate the homography from `src_node` to `dst_node` and judge it by inlier count.

    Args:
        src_node: node whose keypoints are the source points.
        dst_node: node whose keypoints are the destination points.
        min_inliers_threshold: minimum number of RANSAC inliers for the
            homography to count as valid.
        max_desc_dist: descriptor-distance cut-off forwarded to `find_matches`.

    Returns:
        tuple: `(H, is_H_valid, num_inliers)`. `is_H_valid` is 1 or 0. When no
        homography can be estimated (fewer than four matches, or OpenCV
        returning nothing) the result is `(None, 0, 0)` rather than an error.
    """
    _, src_pts, dst_pts = find_matches(src_node, dst_node, max_desc_dist)

    # findHomography needs at least four correspondences.
    if len(src_pts) < 4:
        return None, 0, 0

    H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, RANSAC_THRESHOLD)
    # OpenCV hands back no matrix/mask when the estimation fails.
    if H is None or mask is None:
        return None, 0, 0

    num_inliers = int(np.sum(mask))
    is_H_valid = 1 if num_inliers >= min_inliers_threshold else 0
    return H, is_H_valid, num_inliers

In [None]:
def find_direct_homographies(image_nodes, image_node_ref, min_inliers_threshold, max_desc_dist):
    """Attach a homography to every node that can be registered directly to the reference.

    Each node is tested against `image_node_ref` with `find_valid_homography`.
    Nodes that pass get `homography` set to the estimated matrix and `footprint`
    set to `['ref']`. The footprint is assigned rather than appended to, so
    running this again on the same nodes does not accumulate duplicates.

    Args:
        image_nodes: list of nodes with writable `homography` and `footprint`.
        image_node_ref: the reference node.
        min_inliers_threshold: minimum inlier count for a valid homography.
        max_desc_dist: descriptor-distance cut-off for matching.

    Returns:
        dict: `{1: [...]}` listing the positions (in `image_nodes`) of the nodes
        with a valid direct homography.
    """
    valid_homographies = {1: []}

    for i, node in enumerate(tqdm(image_nodes)):
        H, is_H_valid, _ = find_valid_homography(node, image_node_ref, min_inliers_threshold, max_desc_dist)

        if is_H_valid:
            node.homography = H
            node.footprint = ['ref']
            valid_homographies[1].append(i)
    return valid_homographies

# Level 1: positions of the frames whose homography to the reference was accepted.
valid_homographies = find_direct_homographies(image_nodes, image_node_ref, min_inliers_threshold, max_desc_dist)

In [None]:
from itertools import chain

def find_remaining_homographies(image_nodes, valid_homographies, min_inliers_threshold, max_desc_dist):
    """Register the not-yet-solved nodes by chaining through already-solved ones.

    Levels are processed in increasing order. At each level every unsolved node
    is tested against the solved node of that level whose index is closest to
    its own `idx`. On success the node's homography becomes the composition of
    the anchor's homography with the node-to-anchor homography, its footprint
    becomes `[anchor_idx, *anchor.footprint]`, and its `idx` is recorded under
    `level + 1`.

    Node `idx` values are used as positions into `image_nodes`.

    Args:
        image_nodes: list of nodes with `idx`, `homography` and `footprint`.
        valid_homographies: dict mapping level to a list of solved node indices;
            updated in place.
        min_inliers_threshold: minimum inlier count for a valid homography.
        max_desc_dist: descriptor-distance cut-off for matching.

    Returns:
        None. `valid_homographies` and the nodes are modified in place.
    """
    # Solved indices are tracked in a set instead of being rebuilt for every node.
    solved = set(chain.from_iterable(valid_homographies.values()))

    level = 1
    while level in valid_homographies:
        anchor_idxs = valid_homographies[level]
        next_level = level + 1
        level = next_level

        # A level without solved nodes offers nothing to chain through
        # (e.g. when no frame matched the reference directly).
        if not anchor_idxs:
            continue

        for node in tqdm(image_nodes):
            if node.idx in solved:
                continue

            nearest_idx = min(anchor_idxs, key=lambda x: abs(x - node.idx))
            anchor = image_nodes[nearest_idx]
            H, is_H_valid, _ = find_valid_homography(node, anchor, min_inliers_threshold, max_desc_dist)

            if is_H_valid:
                # H maps node -> anchor. Provided anchor.homography maps
                # anchor -> reference (as stored by find_direct_homographies),
                # the product maps node -> reference: the right-most factor is
                # applied to a point first.
                node.homography = np.dot(anchor.homography, H)
                # Assigned (not appended) so a repeated run cannot duplicate entries.
                node.footprint = [nearest_idx, *anchor.footprint]
                valid_homographies.setdefault(next_level, []).append(node.idx)
                solved.add(node.idx)

# Extends valid_homographies (and the nodes' homography/footprint) in place.
find_remaining_homographies(image_nodes, valid_homographies, min_inliers_threshold, max_desc_dist)

In [18]:
def show_keypoints_matches(node1, node2, max_desc_dist, max_matches=50):
    """
    Show the keypoints matches between two image nodes.

    Args:
        node1 (ImageNode): The first image node.
        node2 (ImageNode): The second image node.
        max_desc_dist: Descriptor-distance cut-off forwarded to `find_matches`.
        max_matches (int): Number of lowest-distance matches to draw.
    """
    matches, _, _ = find_matches(node1, node2, max_desc_dist)

    # drawMatches looks keypoints up by m.queryIdx / m.trainIdx, which index the
    # nodes' full keypoint arrays, so the KeyPoint lists must cover every
    # keypoint of each node (not only the matched ones).
    keypoints1 = [cv2.KeyPoint(float(kp[0]), float(kp[1]), 1) for kp in node1.keypoints]
    keypoints2 = [cv2.KeyPoint(float(kp[0]), float(kp[1]), 1) for kp in node2.keypoints]

    # Lowest descriptor distance first.
    strongest = sorted(matches, key=lambda m: m.distance)[:max_matches]

    # Draw the matches
    img_matches = cv2.drawMatches(node1.image, keypoints1, node2.image, keypoints2, strongest, None, flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS)

    # Display the matches
    plt.figure(figsize=(20, 10))
    plt.imshow(img_matches)
    plt.title(f"Keypoints Matches between Image {node1.idx} and Image {node2.idx}")
    plt.axis('off')
    plt.show()

In [None]:
# counter = 0
# for node in image_nodes:
#     if counter == 3:
#         break
#     if node.footprint == ['ref']:
#         matches, src_pts, dst_pts = find_matches(node, image_node_ref, max_desc_dist)
#         valid_matches = [m for m in matches if m.queryIdx < len(node.keypoints) and m.trainIdx < len(ref_keypoints)]
#         show_keypoints_matches(node, image_node_ref, max_desc_dist)
#         # show_keypoints_matches(node, image_nodes[node.footprint[0]])
#         counter += 1

In [None]:
from plot import visualize_homography_alignment_single

# Preview the alignment of the first three frames whose footprint is exactly
# ['ref'], i.e. those registered straight to the reference image.
directly_registered = [node for node in image_nodes if node.footprint == ['ref']]
for node in directly_registered[:3]:
    visualize_homography_alignment_single(ref_img, node)


In [None]:
from plot import create_canvas_with_images

# Build the canvas from the reference image and all image nodes.
final_canvas = create_canvas_with_images(ref_img, image_nodes)

# Render the resulting canvas.
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(final_canvas)
ax.set_title("Final Canvas with All Images in Reference Frame")
ax.axis('off')
plt.show()