# Align images using keypoints and homography

In [1]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

### Load images

In [2]:
# %% load images
import os
import tifffile as tiff
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider
# import ipywidgets as widgets
# from IPython.display import display

# Directory containing the images
directory = os.path.join('data', 'beads')

# List to hold image paths
image_paths = []

# Load the image paths
for filename in os.listdir(directory):
    if filename.endswith('.tif'):
        image_paths.append(os.path.join(directory, filename))

# Ensure there are images in the directory
if len(image_paths) == 0:
    raise ValueError("No TIFF images found in the directory.")

# Function to load the z-stack from a TIFF file
def load_z_stack(image_path):
    with tiff.TiffFile(image_path) as tif:
        z_stack = [page.asarray() for page in tif.pages]
    return z_stack

z_stacks = [load_z_stack(image_path) for image_path in image_paths]

## Functionality

### Round 1

In [3]:
def detect_and_compute_keypoints(image, detector):
    keypoints, descriptors = detector.detectAndCompute(image, None)
    return keypoints, descriptors


def match_descriptors(descriptors1, descriptors2, distance_metric=cv2.NORM_L2, cross_check=True):
    bf = cv2.BFMatcher(distance_metric, crossCheck=cross_check)
    matches = bf.match(descriptors1, descriptors2)
    matches = sorted(matches, key=lambda x: x.distance)
    return matches


def filter_matches_by_distance(matches, keypoints1, keypoints2, max_distance=100):
    filtered_matches = []
    for match in matches:
        pt1 = keypoints1[match.queryIdx].pt
        pt2 = keypoints2[match.trainIdx].pt
        distance = np.linalg.norm(np.array(pt1) - np.array(pt2))
        if distance < max_distance:
            filtered_matches.append(match)
    return filtered_matches


def compute_homography(filtered_matches, keypoints1, keypoints2):
    if len(filtered_matches) >= 4:
        src_pts = np.float32([keypoints1[m.queryIdx].pt for m in filtered_matches]).reshape(-1, 1, 2)
        dst_pts = np.float32([keypoints2[m.trainIdx].pt for m in filtered_matches]).reshape(-1, 1, 2)
        H, mask = cv2.findHomography(dst_pts, src_pts, cv2.RANSAC, 5.0)
        return H, mask
    else:
        print("Not enough matches are found to compute a reliable homography.")
        return None, None


def warp_image(image, H, shape):
    return cv2.warpPerspective(image, H, shape)


def visualize_keypoints_and_matches(image1, keypoints1, image2, keypoints2, matches, title):
    img_matches = cv2.drawMatches(image1, keypoints1, image2, keypoints2, matches, None, flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS)
    plt.figure(figsize=(15, 5))
    plt.title(title)
    plt.imshow(img_matches)
    plt.axis('off')
    plt.show()


def visualize_warped_image(image1, image2, warped_image):
    plt.figure(figsize=(18, 6))

    plt.subplot(1, 3, 1)
    plt.title('Plane 1 Slice 1')
    plt.imshow(image1, cmap='gray')
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.title('Plane 2 Slice 1')
    plt.imshow(image2, cmap='gray')
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.title('Warped Plane 2 Slice 1')
    plt.imshow(warped_image, cmap='gray')
    plt.axis('off')

    plt.tight_layout()
    plt.show()


def overlay_and_difference_images(image1, warped_image):
    overlay = cv2.addWeighted(image1, 0.5, warped_image, 0.5, 0)
    difference = cv2.absdiff(image1, warped_image)

    plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    plt.title('Overlay of Plane 1 and Warped Plane 2')
    plt.imshow(overlay, cmap='gray')
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.title('Difference between Plane 1 and Warped Plane 2')
    plt.imshow(difference, cmap='gray')
    plt.axis('off')

    plt.tight_layout()
    plt.show()


def calculate_keypoint_displacement(filtered_matches, keypoints1, keypoints2, H):
    src_pts = np.float32([keypoints1[m.queryIdx].pt for m in filtered_matches]).reshape(-1, 2)
    dst_pts = np.float32([keypoints2[m.trainIdx].pt for m in filtered_matches]).reshape(-1, 2)
    warped_keypoints = cv2.perspectiveTransform(dst_pts.reshape(-1, 1, 2), H).reshape(-1, 2)
    displacement = np.linalg.norm(warped_keypoints - src_pts, axis=1)

    print('Displacement of keypoints after warping:')
    for i, d in enumerate(displacement):
        print(f'Keypoint {i}: Displacement = {d:.2f} pixels')

    plt.figure(figsize=(8, 8))
    plt.scatter(src_pts[:, 0], src_pts[:, 1], color='red', label='Original Keypoints')
    plt.scatter(warped_keypoints[:, 0], warped_keypoints[:, 1], color='blue', label='Warped Keypoints', alpha=0.6)
    plt.legend()
    plt.title('Keypoint Displacement due to Warping')
    plt.gca().invert_yaxis()
    plt.show()


def main(plane1_slice1, plane2_slice1):
    # Initialize SIFT detector
    sift = cv2.SIFT_create()

    # Detect keypoints and compute descriptors
    keypoints1, descriptors1 = detect_and_compute_keypoints(plane1_slice1, sift)
    keypoints2, descriptors2 = detect_and_compute_keypoints(plane2_slice1, sift)

    # Match descriptors
    matches = match_descriptors(descriptors1, descriptors2)

    # # Visualize initial matches
    # visualize_keypoints_and_matches(plane1_slice1, keypoints1, plane2_slice1, keypoints2, matches, 'Matches between Plane 1 Slice 1 and Plane 2 Slice 1 using SIFT')

    # Filter matches by distance
    filtered_matches = filter_matches_by_distance(matches, keypoints1, keypoints2, max_distance=100)

    print(f'Filtered matches within distance threshold: {len(filtered_matches)}')

    # Visualize filtered matches
    visualize_keypoints_and_matches(plane1_slice1, keypoints1, plane2_slice1, keypoints2, filtered_matches, 'Filtered Matches between Plane 1 and Plane 2 (with Distance Filter)')

    # Compute homography and warp the second image
    H, _ = compute_homography(filtered_matches, keypoints1, keypoints2)
    if H is not None:
        height, width = plane1_slice1.shape
        warped_plane2 = warp_image(plane2_slice1, H, (width, height))

        # Visualize the aligned images
        visualize_warped_image(plane1_slice1, plane2_slice1, warped_plane2)

        # Visualize overlay and difference
        overlay_and_difference_images(plane1_slice1, warped_plane2)

        # Calculate and visualize keypoint displacement
        calculate_keypoint_displacement(filtered_matches, keypoints1, keypoints2, H)

In [None]:
def normalize_and_convert_to_uint8(image1, image2):
    """
    Normalize images to the range [0, 255] and convert them to uint8.

    Args:
        image1 (numpy.ndarray): The first image to process.
        image2 (numpy.ndarray): The second image to process.

    Returns:
        tuple: Two images, normalized and converted to uint8.
    """
    # Normalize the images to the range [0, 255]
    image1_normalized = cv2.normalize(image1, None, 0, 255, cv2.NORM_MINMAX)
    image2_normalized = cv2.normalize(image2, None, 0, 255, cv2.NORM_MINMAX)

    # Convert the images to uint8
    image1_uint8 = image1_normalized.astype(np.uint8)
    image2_uint8 = image2_normalized.astype(np.uint8)

    return image1_uint8, image2_uint8


In [4]:
def normalize_and_convert_to_uint8_single(image):
    """
    Normalize an image to the range [0, 255] and convert it to uint8.

    Args:
        image (numpy.ndarray): The image to process.

    Returns:
        numpy.ndarray: The image, normalized and converted to uint8.
    """
    # Normalize the image to the range [0, 255]
    image_normalized = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX)

    # Convert the image to uint8
    image_uint8 = image_normalized.astype(np.uint8)

    return image_uint8

In [None]:
plane1_slice1 = z_stacks[0][0]
plane2_slice1 = z_stacks[1][0]
plane1_slice1, plane2_slice1 = normalize_and_convert_to_uint8(plane1_slice1, plane2_slice1)

# Run the main function
main(plane1_slice1, plane2_slice1)


### Round 2


In [None]:
def main(plane1_slice1, plane2_slice1):
    # Initialize SIFT detector
    sift = cv2.SIFT_create()

    # Detect keypoints and compute descriptors
    keypoints1, descriptors1 = detect_and_compute_keypoints(plane1_slice1, sift)
    keypoints2, descriptors2 = detect_and_compute_keypoints(plane2_slice1, sift)

    # Match descriptors
    matches = match_descriptors(descriptors1, descriptors2)

    # Filter matches by distance
    filtered_matches = filter_matches_by_distance(matches, keypoints1, keypoints2, max_distance=100)

    print(f'Filtered matches within distance threshold: {len(filtered_matches)}')

    # Visualize filtered matches
    visualize_keypoints_and_matches(plane1_slice1, keypoints1, plane2_slice1, keypoints2, filtered_matches, 'Filtered Matches between Plane 1 and Plane 2 (with Distance Filter)')

    # Compute homography and warp the second image
    H, _ = compute_homography(filtered_matches, keypoints1, keypoints2)
    if H is not None:
        height, width = plane1_slice1.shape
        warped_plane2 = warp_image(plane2_slice1, H, (width, height))

        # Visualize the aligned images
        visualize_warped_image(plane1_slice1, plane2_slice1, warped_plane2)

        # Visualize overlay and difference
        overlay_and_difference_images(plane1_slice1, warped_plane2)

        # Calculate and visualize keypoint displacement
        calculate_keypoint_displacement(filtered_matches, keypoints1, keypoints2, H)

In [None]:
main(plane1_slice1, plane2_slice1)


In [5]:
def align_and_return_transform(plane1_slice1, plane2_slice1):
    """
    Process the images to detect keypoints, match descriptors, filter matches, 
    compute homography, and warp the second image.

    Parameters:
    plane1_slice1 (numpy.ndarray): The first image.
    plane2_slice1 (numpy.ndarray): The second image.

    Returns:
    tuple: (filtered_matches, keypoints1, keypoints2, H, warped_plane2)
        - filtered_matches: The list of filtered matches.
        - keypoints1: Keypoints detected in the first image.
        - keypoints2: Keypoints detected in the second image.
        - H: The computed homography matrix.
        - warped_plane2: The second image after warping with the homography.
    """
    # Initialize SIFT detector
    sift = cv2.SIFT_create()

    # Detect keypoints and compute descriptors
    keypoints1, descriptors1 = detect_and_compute_keypoints(plane1_slice1, sift)
    keypoints2, descriptors2 = detect_and_compute_keypoints(plane2_slice1, sift)

    # Match descriptors
    matches = match_descriptors(descriptors1, descriptors2)

    # Filter matches by distance
    filtered_matches = filter_matches_by_distance(matches, keypoints1, keypoints2, max_distance=10)

    print(f'Filtered matches within distance threshold: {len(filtered_matches)}')

    # Compute homography
    H, _ = compute_homography(filtered_matches, keypoints1, keypoints2)

    # Warp the second image if homography is computed
    warped_plane2 = None
    if H is not None:
        height, width = plane1_slice1.shape
        warped_plane2 = warp_image(plane2_slice1, H, (width, height))

    return warped_plane2, filtered_matches, keypoints1, keypoints2, H

In [6]:
warped_plane2, filtered_matches, keypoints1, keypoints2, H = align_and_return_transform(plane1_slice1, plane2_slice1)

NameError: name 'plane1_slice1' is not defined

In [None]:
def align_stacks(plane1_stack, plane2_stack, plane3_stack):
    """Align the 1st and 3rd planes of image stacks to the 2nd plane and return alignment info."""
    aligned_plane1_stack = []
    aligned_plane3_stack = []

    all_info_plane1 = []
    all_info_plane3 = []

    for i in range(len(plane2_stack)):
        # Normalize and convert each slice to uint8
        plane1_slice = normalize_and_convert_to_uint8_single(plane1_stack[i])
        plane2_slice = normalize_and_convert_to_uint8_single(plane2_stack[i])
        plane3_slice = normalize_and_convert_to_uint8_single(plane3_stack[i])

        # Align the 1st slice to the 2nd
        aligned_plane1_slice, filtered_matches1, keypoints1_1, keypoints2_1, H1 = align_and_return_transform(plane2_slice, plane1_slice)
        aligned_plane3_slice, filtered_matches3, keypoints1_3, keypoints2_3, H3 = align_and_return_transform(plane2_slice, plane3_slice)

        if aligned_plane1_slice is not None:
            aligned_plane1_stack.append(aligned_plane1_slice)
        else:
            print(f"Warning: Alignment failed for slice {i} in Plane 1.")
        
        if aligned_plane3_slice is not None:
            aligned_plane3_stack.append(aligned_plane3_slice)
        else:
            print(f"Warning: Alignment failed for slice {i} in Plane 3.")

        # Store alignment information
        all_info_plane1.append((filtered_matches1, keypoints1_1, keypoints2_1, H1))
        all_info_plane3.append((filtered_matches3, keypoints1_3, keypoints2_3, H3))
        
        # break

    return (np.array(aligned_plane1_stack), plane2_stack, np.array(aligned_plane3_stack)), all_info_plane1, all_info_plane3

In [7]:
def align_single_plane(plane_slice, ref_slice):
    """Align a single plane slice to a reference slice and return alignment info."""
    # Normalize and convert slices to uint8
    plane_slice = normalize_and_convert_to_uint8_single(plane_slice)
    ref_slice = normalize_and_convert_to_uint8_single(ref_slice)

    # Align the plane slice to the reference slice and return the transformation info
    aligned_slice, filtered_matches, keypoints1, keypoints2, H = align_and_return_transform(ref_slice, plane_slice)
    
    if aligned_slice is None:
        aligned_slice = plane_slice

    return aligned_slice, (filtered_matches, keypoints1, keypoints2, H)

def align_stacks(plane1_stack, plane2_stack, plane3_stack):
    """Align the 1st and 3rd planes of image stacks to the 2nd plane and return alignment info."""
    aligned_plane1_stack, aligned_plane3_stack = [], []
    all_info_plane1, all_info_plane3 = [], []

    for i in range(len(plane2_stack)):
        # Align slices for Plane 1 and Plane 3 to Plane 2
        aligned_plane1_slice, info_plane1 = align_single_plane(plane1_stack[i], plane2_stack[i])
        aligned_plane3_slice, info_plane3 = align_single_plane(plane3_stack[i], plane2_stack[i])

        # Store the aligned slices if alignment succeeded
        if aligned_plane1_slice is not None:
            aligned_plane1_stack.append(aligned_plane1_slice)
        else:
            print(f"Warning: Alignment failed for slice {i} in Plane 1.")
        
        if aligned_plane3_slice is not None:
            aligned_plane3_stack.append(aligned_plane3_slice)
        else:
            print(f"Warning: Alignment failed for slice {i} in Plane 3.")

        # Store alignment information
        all_info_plane1.append(info_plane1)
        all_info_plane3.append(info_plane3)

    return (
        (np.array(aligned_plane1_stack), plane2_stack, np.array(aligned_plane3_stack)),
        all_info_plane1,
        all_info_plane3
    )

In [8]:
plane1_stack = z_stacks[0]
plane2_stack = z_stacks[1]
plane3_stack = z_stacks[2]

(aligned_plane1_stack, aligned_plane2_stack, aligned_plane3_stack), info_plane1, info_plane3 = align_stacks(plane1_stack, plane2_stack, plane3_stack)

Filtered matches within distance threshold: 7
Filtered matches within distance threshold: 6
Filtered matches within distance threshold: 7
Filtered matches within distance threshold: 7
Filtered matches within distance threshold: 9
Filtered matches within distance threshold: 5
Filtered matches within distance threshold: 13
Filtered matches within distance threshold: 4
Filtered matches within distance threshold: 12
Filtered matches within distance threshold: 4
Filtered matches within distance threshold: 12
Filtered matches within distance threshold: 6
Filtered matches within distance threshold: 13
Filtered matches within distance threshold: 5
Filtered matches within distance threshold: 16
Filtered matches within distance threshold: 7
Filtered matches within distance threshold: 14
Filtered matches within distance threshold: 6
Filtered matches within distance threshold: 20
Filtered matches within distance threshold: 8
Filtered matches within distance threshold: 20
Filtered matches within di

In [None]:
plane2_stack[0].shape

In [None]:
len(info_plane1)
len(info_plane3[0])


### Display aligned images

In [9]:
def display_slices_with_slider(z_stacks, image_paths):
    """
    Display slices from multiple image stacks side by side with a slider to navigate through slices.

    Parameters:
    z_stacks (list of np.ndarray): List of image stacks, where each stack is a 3D NumPy array.
    image_paths (list of str): List of file paths or names for the image stacks, used for labeling.

    Usage:
    display_slices_with_slider(z_stacks, image_paths)
    """
    def show_slices(slice_num):
        fig, axes = plt.subplots(1, len(z_stacks), figsize=(20, 8))
        for ax, z_stack, image_path in zip(axes, z_stacks, image_paths):
            ax.imshow(z_stack[slice_num], cmap='gray')
            ax.set_title(f"{os.path.basename(image_path)} - Slice {slice_num + 1}")
            ax.axis('off')
        plt.show()

    # Determine the number of slices (assuming all stacks have the same number of slices)
    num_slices = len(z_stacks[0])

    # Create a slider
    slider = IntSlider(min=0, max=num_slices - 1, step=1, value=0, description='Slice')

    # Display the slices side by side with the slider
    interact(show_slices, slice_num=slider)

z_stacks = [aligned_plane1_stack, aligned_plane2_stack, aligned_plane3_stack]
display_slices_with_slider(z_stacks, image_paths)

interactive(children=(IntSlider(value=0, description='Slice', max=60), Output()), _dom_classes=('widget-intera…

## Align to offset slice

In [13]:
def align_stacks_to_previous_slices(plane1_stack, plane2_stack, plane3_stack, offset=7):
    """Align each slice to the slice `offset` positions above it in the stack."""
    aligned_plane1_stack, aligned_plane3_stack = [], []
    all_info_plane1, all_info_plane3 = [], []

    for i in range(len(plane2_stack)):
        if i >= offset and i < len(plane2_stack) - offset - 1:
            # Align the current slice to the slice `offset` positions below and above it
            aligned_plane1_slice, info_plane1 = align_single_plane(plane1_stack[i - offset], plane2_stack[i])
            aligned_plane3_slice, info_plane3 = align_single_plane(plane3_stack[i + offset], plane2_stack[i])

            # Store the aligned slices (or original slices if no alignment was possible)
            aligned_plane1_stack.append(aligned_plane1_slice)
            aligned_plane3_stack.append(aligned_plane3_slice)

            # Store alignment information
            all_info_plane1.append(info_plane1)
            all_info_plane3.append(info_plane3)

    return (
        (np.array(aligned_plane1_stack), plane2_stack, np.array(aligned_plane3_stack)),
        all_info_plane1,
        all_info_plane3
    )

In [None]:
len(plane2_stack)

In [14]:
(aligned_plane1_stack, aligned_plane2_stack, aligned_plane3_stack), info_plane1, info_plane3 = align_stacks_to_previous_slices(plane1_stack, plane2_stack, plane3_stack)

Filtered matches within distance threshold: 27
Filtered matches within distance threshold: 24
Filtered matches within distance threshold: 29
Filtered matches within distance threshold: 25
Filtered matches within distance threshold: 38
Filtered matches within distance threshold: 35
Filtered matches within distance threshold: 41
Filtered matches within distance threshold: 38
Filtered matches within distance threshold: 48
Filtered matches within distance threshold: 43
Filtered matches within distance threshold: 49
Filtered matches within distance threshold: 43
Filtered matches within distance threshold: 58
Filtered matches within distance threshold: 50
Filtered matches within distance threshold: 68
Filtered matches within distance threshold: 51
Filtered matches within distance threshold: 73
Filtered matches within distance threshold: 59
Filtered matches within distance threshold: 92
Filtered matches within distance threshold: 61
Filtered matches within distance threshold: 102
Filtered mat

In [15]:
z_stacks = [aligned_plane1_stack, aligned_plane2_stack, aligned_plane3_stack]
display_slices_with_slider(z_stacks, image_paths)

interactive(children=(IntSlider(value=0, description='Slice', max=45), Output()), _dom_classes=('widget-intera…