Required Libraries

In [None]:
# LIBRARY IMPORTS AND INSTALLATION INSTRUCTIONS

# To ensure that the code can run properly, please install the following packages using pip or conda:

# !pip install opencv-python-headless numpy torch torchvision pillow scipy pandas matplotlib seaborn plotly scikit-image psutil scikit-learn

# Essential Libraries
import cv2
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torchvision import transforms, models

# Statistical Analysis
from scipy.signal import correlate2d
import math
import pandas as pd

# Data Visualization
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px

# Error Handling
import warnings

# Performance Measurements
import time
from psutil import virtual_memory

# Image Analysis
from skimage import io
from skimage.metrics import normalized_root_mse as ncc, structural_similarity as ssim
from scipy.ndimage.filters import gaussian_gradient_magnitude

# Machine Learning
from sklearn.base import BaseEstimator

In [None]:
# This line of code is used to set the device for computation.
# If a CUDA-compatible GPU (Graphics Processing Unit) is available on the system,
# it will be set as the default device for tensor computations, which can significantly 
# speed up deep learning operations. If a CUDA-compatible GPU is not available, 
# the computations will default to the system's CPU (Central Processing Unit).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# This function defines the likelihood of a model given the data.
# This function uses a Gaussian (also known as Normal) distribution.
# The Gaussian distribution is parameterized by its mean and standard deviation (sigma).
# In this case, the 'error' variable represents the difference between the data and the model's predictions.
# The function returns a measure of the likelihood that the model generated the data.
# If the error is small, the function will return a large value (indicating high likelihood), and vice versa.
# 
# Input Parameters:
# - error: The difference between the observed data and the model's predictions.
# - sigma: The standard deviation of the Gaussian distribution. It's a measure of the variability in the data.
#
# Returns:
# - likelihood: A measure of how likely it is that the model generated the observed data.

def likelihood(error, sigma):
    """Return the unnormalised Gaussian likelihood ``exp(-error**2 / (2 * sigma**2))``.

    Args:
        error: Residual(s) between observation and prediction. May be a scalar,
            a NumPy array, or any array-like (list/tuple) of numbers.
        sigma: Standard deviation of the Gaussian. Must be non-zero.

    Returns:
        A NumPy scalar (scalar input) or array (array-like input) with values
        in ``[0, 1]``; ``1`` means a perfect fit.

    Raises:
        ValueError: If ``sigma`` is zero (the expression would divide by zero).
    """
    if np.any(np.equal(sigma, 0)):
        raise ValueError("sigma must be non-zero")
    # asarray lets plain Python sequences through without changing the dtype
    # of arrays that are passed in.
    error = np.asarray(error)
    # Dividing the (non-negative) square by a negative denominator avoids
    # negating the data itself, which would wrap around for unsigned dtypes.
    return np.exp(np.square(error) / (-2.0 * sigma**2))


# This class defines a Homography Model which inherits from the BaseEstimator class provided by scikit-learn.
# Homography is a transformation that maps the points in one image to the corresponding points in the other image.
# A homography matrix is a 3x3 matrix that performs this transformation.
# The HomographyModel has two main methods:
# 1. fit: This method calculates the homography matrix between source points and destination points.
# 2. errors: This method calculates the reprojection error, which is the Euclidean distance between
#    the destination points and the source points transformed by the estimated homography matrix.
#
class HomographyModel(BaseEstimator):
    """Homography estimator exposing a scikit-learn style ``fit`` method.

    Point correspondences are passed as one ``(n, 4)`` array whose rows are
    ``[src_x, src_y, dst_x, dst_y]`` (i.e. the first two *columns* hold the
    source points and the last two columns the destination points).

    After ``fit`` the attribute ``H`` holds the homography returned by
    ``cv2.findHomography``; it is ``None`` when OpenCV could not estimate one.
    """

    @staticmethod
    def _split_points(X):
        """Split an ``(n, 4)`` correspondence array into two ``(n, 1, 2)`` point arrays."""
        X = np.asarray(X)
        if X.ndim != 2 or X.shape[1] != 4:
            raise ValueError(
                f"X must have shape (n, 4) with rows [src_x, src_y, dst_x, dst_y]; got {X.shape}"
            )
        # OpenCV's point functions only accept floating point input.
        if X.dtype not in (np.float32, np.float64):
            X = X.astype(np.float64)
        src_pts = X[:, :2].reshape(-1, 1, 2)
        dst_pts = X[:, 2:].reshape(-1, 1, 2)
        return src_pts, dst_pts

    def fit(self, X, y=None):
        """Estimate the homography mapping source points onto destination points.

        Args:
            X: ``(n, 4)`` array of correspondences ``[src_x, src_y, dst_x, dst_y]``.
            y: Ignored; present for scikit-learn API compatibility.

        Returns:
            ``self``. ``self.H`` is the estimated matrix, or ``None`` if the
            estimation failed (e.g. a degenerate point configuration).

        Raises:
            ValueError: If ``X`` is not a two-dimensional array with four columns.
        """
        src_pts, dst_pts = self._split_points(X)
        self.H, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC)
        return self

    def errors(self, X, y=None):
        """Return the reprojection error of every correspondence in ``X``.

        The error is the Euclidean distance between each destination point and
        its source point transformed by ``self.H``.

        Args:
            X: ``(n, 4)`` array of correspondences ``[src_x, src_y, dst_x, dst_y]``.
            y: Ignored; present for scikit-learn API compatibility.

        Returns:
            1D array of length ``n``. All entries are ``inf`` when ``self.H`` is
            ``None``, so a failed fit scores as "fits nothing" instead of
            crashing inside OpenCV.

        Raises:
            ValueError: If ``X`` is not a two-dimensional array with four columns.
        """
        src_pts, dst_pts = self._split_points(X)
        if self.H is None:
            return np.full(src_pts.shape[0], np.inf, dtype=src_pts.dtype)
        dst_pts_estimated = cv2.perspectiveTransform(src_pts, self.H)
        return np.sqrt(np.sum((dst_pts - dst_pts_estimated) ** 2, axis=2)).ravel()



# This function implements the MLESAC algorithm, which is a robust method for fitting a model
# to data that may contain outliers. MLESAC uses maximum likelihood estimation to evaluate the quality 
# of the fitted models, instead of the consensus set size used by RANSAC. This lets it select
# the candidate model with the highest likelihood.
#
# The mlesac function:
# 1. Randomly selects a subset of the data and fits the model to this subset.
# 2. Calculates the errors of the data points to the fitted model.
# 3. Classifies the data points as inliers or outliers based on a threshold.
# 4. Calculates the likelihood of the inliers and the outliers.
# 5. If the total likelihood is higher than the best likelihood found so far, updates the best model, 
#    the best likelihood, and the best inliers.
# 6. Repeats the process for a specified number of iterations.
#
# Input Parameters:
# - data: The input data.
# - model_class: The class of the model to be fitted.
# - min_samples: The minimum number of data points required to fit the model.
# - max_iterations: The maximum number of iterations for the algorithm.
# - sigma: The standard deviation of the Gaussian that is used to estimate the likelihood.
# - threshold: The threshold used to classify the data points as inliers or outliers.
#
# Returns:
# - best_model: The best fitted model.
# - best_inliers: The inliers of the best fitted model.

def mlesac(data, model_class, min_samples, max_iterations, sigma, threshold, rng=None):
    """Fit ``model_class`` robustly by maximising a Gaussian likelihood score.

    Each iteration fits a model to ``min_samples`` randomly drawn rows and
    scores it with the sum of ``likelihood(error, sigma)`` over *all* rows.
    The score does not depend on ``threshold``; the threshold only decides
    which rows are reported as inliers of the winning model.

    Args:
        data: 2D array, one observation per row.
        model_class: Zero-argument callable returning an object with
            ``fit(sample)`` and ``errors(data)`` methods.
        min_samples: Number of rows drawn per iteration.
        max_iterations: Number of random samples to try.
        sigma: Standard deviation passed to ``likelihood``.
        threshold: Maximum error for a row to be reported as an inlier.
        rng: Optional object with a NumPy-compatible ``choice`` method (for
            example ``np.random.default_rng(seed)``) for reproducible runs.
            Defaults to the global ``np.random`` state.

    Returns:
        ``(best_model, best_inliers)``. Both are ``None`` when no iteration
        produced a model with finite errors (or ``max_iterations`` is 0).

    Raises:
        ValueError: If ``data`` has fewer rows than ``min_samples``.
    """
    data = np.asarray(data)
    n_points = data.shape[0]
    if n_points < min_samples:
        raise ValueError(
            f"mlesac needs at least {min_samples} data points, got {n_points}"
        )
    sampler = np.random if rng is None else rng

    best_model = None
    # -inf (not 0) so that a valid model is still returned when every
    # likelihood underflows to exactly zero.
    best_likelihood = -np.inf
    best_inliers = None

    for _ in range(max_iterations):
        # Fit a candidate model to a random minimal subset of the data.
        sample = data[sampler.choice(n_points, min_samples, replace=False)]
        model = model_class()
        model.fit(sample)

        errors = np.asarray(model.errors(data))
        # Non-finite errors indicate a degenerate candidate; never select it.
        if not np.all(np.isfinite(errors)):
            continue

        likelihood_total = np.sum(likelihood(errors, sigma))
        if likelihood_total > best_likelihood:
            best_model = model
            best_likelihood = likelihood_total
            best_inliers = data[errors <= threshold]

    return best_model, best_inliers


# The draw_keypoint_matches function visualizes the matching keypoints between two images.
# It takes as input two images, their keypoints, the matches between the keypoints, an adaptive threshold, 
# and a path to save the resulting image. The function draws only the matches that have a distance 
# less than or equal to the adaptive threshold.
# 
# Input Parameters:
# - img1: The first image.
# - img2: The second image.
# - keypoints1: The keypoints detected in the first image.
# - keypoints2: The keypoints detected in the second image.
# - matches: The matches between the keypoints of the two images.
# - adaptive_threshold: The threshold used to filter the matches. Only matches with a distance 
#   less than or equal to this threshold are drawn.
# - save_path: The path where the resulting image is saved.
# 
# Returns:
# - None. The function saves the resulting image to the specified path.

def draw_keypoint_matches(img1, img2, keypoints1, keypoints2, matches, adaptive_threshold, save_path):
    """Draw the matches whose distance is within ``adaptive_threshold`` and save the picture.

    Args:
        img1, img2: The two images the keypoints belong to.
        keypoints1, keypoints2: Keypoints of ``img1`` and ``img2``.
        matches: Iterable of match objects exposing a ``distance`` attribute.
        adaptive_threshold: Largest match distance that is still drawn.
        save_path: Destination passed to ``cv2.imwrite``.

    Returns:
        None. The rendered image is written to ``save_path``.
    """
    # Keep only the matches that pass the distance threshold.
    kept = []
    for match in matches:
        if match.distance <= adaptive_threshold:
            kept.append(match)

    # Render the surviving matches side by side; unmatched keypoints are not drawn.
    canvas = cv2.drawMatches(
        img1, keypoints1, img2, keypoints2, kept, None,
        flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS,
    )
    cv2.imwrite(save_path, canvas)


# The adjust_brightness function adjusts the brightness of an image.
# It takes as input an image and optionally the scale factor (alpha) and an added constant (beta).
# 
# Input Parameters:
# - img: The input image.
# - alpha: The scale factor that is multiplied with each pixel. Default is 0.0.
# - beta: The constant added to each pixel. Default is 0.0.
# 
# Returns:
# - Adjusted image with the new brightness.
def adjust_brightness(img, alpha=0.0, beta=0.0):
    """Scale and shift pixel values via ``cv2.convertScaleAbs``.

    Args:
        img: Input image.
        alpha: Multiplicative factor forwarded to ``cv2.convertScaleAbs``.
            Note the default of 0.0 multiplies every pixel by zero.
        beta: Additive offset forwarded to ``cv2.convertScaleAbs``.

    Returns:
        The image returned by ``cv2.convertScaleAbs``.
    """
    adjusted = cv2.convertScaleAbs(img, alpha=alpha, beta=beta)
    return adjusted


# The adjust_contrast function adjusts the contrast of an image.
# It takes as input an image and optionally the scale factor (alpha) and an added constant (beta).
# 
# Input Parameters:
# - img: The input image.
# - alpha: The scale factor that is multiplied with each pixel. Default is 0.0.
# - beta: The constant added to each pixel. Default is 0.0.
# 
# Returns:
# - Adjusted image with the new contrast.
def adjust_contrast(img, alpha=0.0, beta=0.0):
    """Scale and shift pixel values via ``cv2.convertScaleAbs``.

    Args:
        img: Input image.
        alpha: Multiplicative factor forwarded to ``cv2.convertScaleAbs``.
            Note the default of 0.0 multiplies every pixel by zero.
        beta: Additive offset forwarded to ``cv2.convertScaleAbs``.

    Returns:
        The image returned by ``cv2.convertScaleAbs``.
    """
    scale_kwargs = {"alpha": alpha, "beta": beta}
    return cv2.convertScaleAbs(img, **scale_kwargs)


# The denoise_image function denoises an image using the Non-local Means Denoising algorithm.
# It takes as input an image and optionally a weight for the denoising process.
# 
# Input Parameters:
# - img: The input image.
# - weight: The weight for the denoising process. Default is 0.1.
# 
# Returns:
# - Denoised image.
def denoise_image(img, weight=0.1):
    """Denoise ``img`` with OpenCV's non-local means filter.

    Args:
        img: Input image.
        weight: Value forwarded as the third positional argument of
            ``cv2.fastNlMeansDenoising`` (the filter strength in OpenCV's
            documented signature).

    Returns:
        The denoised image returned by ``cv2.fastNlMeansDenoising``.
    """
    destination = None  # let OpenCV allocate the output image
    return cv2.fastNlMeansDenoising(img, destination, weight)


# The preprocess_image function applies a series of preprocessing steps to an image.
# It first adjusts the brightness, then the contrast, and finally denoises the image.
# 
# Input Parameters:
# - img: The input image.
# 
# Returns:
# - The preprocessed image.
def preprocess_image(img):
    """Run the preprocessing chain: brightness, then contrast, then denoising.

    Args:
        img: Input image.

    Returns:
        The image after all three steps have been applied in order.
    """
    brightened = adjust_brightness(img, alpha=1.0, beta=0.0)
    contrasted = adjust_contrast(brightened, alpha=1.0, beta=0.0)
    return denoise_image(contrasted, weight=0.0)


# The open_images function opens a fixed image and a list of moving images using their file paths.
# It applies preprocessing steps to all images using the preprocess_image function.
#
# Input Parameters:
# - fixed_image_path: The file path of the fixed image.
# - moving_image_paths: A list of file paths for the moving images.
#
# Returns:
# - The preprocessed fixed image and a list of preprocessed moving images.
def open_images(fixed_image_path, moving_image_paths):
    """Load and preprocess the fixed image and every moving image.

    Args:
        fixed_image_path: File path of the fixed (reference) image.
        moving_image_paths: Iterable of file paths of the moving images.

    Returns:
        ``(fixed_image, moving_images)`` where ``moving_images`` is a list in
        the same order as ``moving_image_paths``; every image has been passed
        through ``preprocess_image``.

    Raises:
        OSError: If any path cannot be read as an image.
    """
    def _load(path):
        # cv2.imread signals failure by returning None instead of raising,
        # so check explicitly and report which path was the problem.
        image = cv2.imread(path)
        if image is None:
            raise OSError(f"Could not read image: {path!r}")
        return preprocess_image(image)

    fixed_image = _load(fixed_image_path)
    moving_images = [_load(path) for path in moving_image_paths]

    return fixed_image, moving_images


# The preprocess_image_for_dl function processes an image for input into a deep learning model.
# It converts the image to RGB, applies a series of transforms, and converts the image to a tensor.
#
# Input Parameters:
# - img: The input image.
#
# Returns:
# - The preprocessed image tensor.
def preprocess_image_for_dl(img, is_bgr=True):
    """Convert an image array into a normalised tensor for an ImageNet-style model.

    Args:
        img: Image as a NumPy array.
        is_bgr: Whether a 3- or 4-channel array is in OpenCV channel order
            (BGR/BGRA), as produced by ``cv2.imread``. When true, the channels
            are reordered to RGB(A) first so that the per-channel
            normalisation constants below are applied to the right channels.
            Pass ``False`` for arrays that are already RGB.

    Returns:
        The tensor produced by Resize(224) -> CenterCrop(224) -> ToTensor ->
        Normalize.
    """
    # PIL assumes RGB order; Image.convert('RGB') does NOT swap BGR to RGB.
    if is_bgr and isinstance(img, np.ndarray) and img.ndim == 3:
        if img.shape[2] == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        elif img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)

    # Wrap in a PIL image and force three channels (also expands grayscale).
    img = Image.fromarray(img).convert('RGB')

    preprocess = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Per-channel mean/std in RGB order.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    img_tensor = preprocess(img)

    return img_tensor


# The features_to_keypoints_and_descriptors function takes an array of features as input,
# resizes (interpolates) it into a square 2D image with cv2.resize, and then uses the SIFT
# feature detection algorithm to extract keypoints and descriptors from that image.
#
# Input Parameters:
# - features_np: A numpy array of features.
#
# Returns:
# - keypoints: The detected keypoints.
# - descriptors: The corresponding descriptors for the detected keypoints.
def features_to_keypoints_and_descriptors(features_np):
    """Turn a feature array into a small image and run SIFT on it.

    Args:
        features_np: NumPy array of features (float32 in this file's caller).

    Returns:
        ``(keypoints, descriptors)`` from ``cv2.SIFT.detectAndCompute``.
        ``descriptors`` is ``None`` when SIFT finds no keypoints.
    """
    # Side length of the square image: floor(sqrt(number of features)).
    side = int(np.sqrt(features_np.size))
    # NOTE: cv2.resize interpolates the input to side x side; it is not a
    # reshape, so the original feature values are not preserved one-to-one.
    img = cv2.resize(features_np, (side, side))

    # SIFT needs an 8-bit image. Clip before casting: values outside [0, 1]
    # would otherwise overflow the uint8 range and wrap around.
    img = np.clip(img * 255, 0, 255).astype(np.uint8)

    sift = cv2.SIFT_create()
    keypoints, descriptors = sift.detectAndCompute(img, None)

    return keypoints, descriptors



# The extract_features function uses both a deep learning model (ResNet) and the SIFT algorithm
# to extract features from a given image.
#
# Input Parameters:
# - img: The image from which features will be extracted.
# - model: The deep learning model (ResNet) used for feature extraction.
#
# Returns:
# - kp: The detected keypoints.
# - des: The corresponding descriptors for the detected keypoints.
def extract_features(img, model):
    """Run ``img`` through ``model`` and derive SIFT keypoints from its output.

    Args:
        img: Image array accepted by ``preprocess_image_for_dl``.
        model: Torch module used as the feature extractor.

    Returns:
        ``(kp, des)`` as returned by ``features_to_keypoints_and_descriptors``.
    """
    # Send the input to wherever the model's weights live, so a model left on
    # the CPU still works when CUDA is available. Fall back to the module-level
    # device for models without parameters.
    try:
        target_device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        target_device = device

    # unsqueeze(0) adds the batch dimension the model expects.
    img_tensor = preprocess_image_for_dl(img).unsqueeze(0).to(target_device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        features = model(img_tensor)

    # Drop singleton dimensions, bring the result to the CPU and convert it to
    # the float32 array expected by the SIFT helper.
    features_np = features.squeeze().cpu().detach().numpy().astype(np.float32)

    kp, des = features_to_keypoints_and_descriptors(features_np)

    return kp, des


# The extract_combined_features function uses both a deep learning model (ResNet) and the SIFT algorithm
# to extract features from a given image, effectively combining both methods.
#
# Input Parameters:
# - image: The image from which features will be extracted.
# - model: The deep learning model (ResNet) used for feature extraction.
#
# Returns:
# - resnet_keypoints: The keypoints detected by the ResNet model.
# - resnet_descriptors: The corresponding descriptors for the ResNet detected keypoints.
# - sift_keypoints: The keypoints detected by the SIFT algorithm.
# - sift_descriptors: The corresponding descriptors for the SIFT detected keypoints.
def extract_combined_features(image, model):
    """Extract model-derived and plain SIFT features from the same image.

    Args:
        image: Image to analyse.
        model: Torch module forwarded to ``extract_features``.

    Returns:
        ``(resnet_keypoints, resnet_descriptors, sift_keypoints,
        sift_descriptors)``: the first pair comes from ``extract_features``,
        the second from running SIFT directly on ``image``.
    """
    # Features derived from the deep-learning model's output.
    deep_result = extract_features(image, model)
    resnet_keypoints, resnet_descriptors = deep_result

    # Classic SIFT on the image itself.
    detector = cv2.SIFT_create()
    sift_result = detector.detectAndCompute(image, None)
    sift_keypoints, sift_descriptors = sift_result

    return resnet_keypoints, resnet_descriptors, sift_keypoints, sift_descriptors



# The homography_registration function performs image registration by finding a homography (a transformation that maps points in one image to another) 
# using combined features from the deep learning model (ResNet) and the SIFT algorithm.
#
# Input Parameters:
# - img1, img2: The pair of images to be registered.
# - model: The deep learning model (ResNet) used for feature extraction.
# - method: The method used to compute the homography (default is RANSAC).
# - threshold: The distance threshold to identify inliers in the RANSAC algorithm.
# - threshold_multiplier: A factor to adjust the adaptive threshold. 
#
# Returns:
# - H: The homography matrix.
# - matches: The matched keypoints from both images.
# - adaptive_threshold: The adaptive threshold used to select good matches.
# - keypoints1, keypoints2: The keypoints detected in each image.
def homography_registration(img1, img2, model, method=cv2.RANSAC, threshold=5.0, threshold_multiplier=5.0):
    """Estimate the homography mapping ``img1`` keypoints onto ``img2`` keypoints.

    SIFT matches (FLANN + ratio test) and ResNet-derived matches (brute force
    with cross-check) are merged, and the homography is fitted with ``mlesac``.

    Args:
        img1, img2: The pair of images to register.
        model: Torch module used for feature extraction.
        method: Currently unused; kept for interface compatibility.
        threshold: Reprojection-error threshold (pixels) handed to ``mlesac``
            for inlier classification.
        threshold_multiplier: Factor applied to the median match distance to
            obtain the adaptive (descriptor-distance) threshold.

    Returns:
        ``(H, matches, adaptive_threshold, keypoints1, keypoints2)``.
        ``keypoints1``/``keypoints2`` are the SIFT keypoints followed by the
        ResNet-derived keypoints of each image; every ``queryIdx``/``trainIdx``
        in ``matches`` is a valid index into these lists.

    Raises:
        ValueError: If fewer than four matches are available.
        RuntimeError: If no homography could be estimated.
    """
    resnet_keypoints1, resnet_descriptors1, sift_keypoints1, sift_descriptors1 = extract_combined_features(img1, model)
    resnet_keypoints2, resnet_descriptors2, sift_keypoints2, sift_descriptors2 = extract_combined_features(img2, model)

    # --- SIFT matches: FLANN kNN (k=2) followed by the ratio test. ---
    good_sift_matches = []
    # detectAndCompute yields None descriptors when no keypoints are found,
    # and k=2 needs at least two candidates on the train side.
    if (sift_descriptors1 is not None and sift_descriptors2 is not None
            and len(sift_descriptors2) >= 2):
        FLANN_INDEX_KDTREE = 1
        index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
        search_params = dict(checks=50)
        flann = cv2.FlannBasedMatcher(index_params, search_params)
        for pair in flann.knnMatch(sift_descriptors1, sift_descriptors2, k=2):
            # Guard against pairs with fewer than two neighbours.
            if len(pair) == 2 and pair[0].distance < 0.7 * pair[1].distance:
                good_sift_matches.append(pair[0])

    # --- ResNet matches: brute force with cross-check. ---
    # Their indices refer to the ResNet keypoint lists, so shift them by the
    # number of SIFT keypoints to index the combined lists built below.
    # NOTE(review): ResNet keypoint coordinates come from the small feature
    # image, not from image pixel space -- confirm they are meant to take part
    # in the geometric fit.
    resnet_matches = []
    if resnet_descriptors1 is not None and resnet_descriptors2 is not None:
        bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
        offset1, offset2 = len(sift_keypoints1), len(sift_keypoints2)
        resnet_matches = [
            cv2.DMatch(m.queryIdx + offset1, m.trainIdx + offset2, m.distance)
            for m in bf.match(resnet_descriptors1, resnet_descriptors2)
        ]

    keypoints1 = list(sift_keypoints1) + list(resnet_keypoints1)
    keypoints2 = list(sift_keypoints2) + list(resnet_keypoints2)

    # Merge both match sets, best (smallest distance) first.
    matches = sorted(good_sift_matches + resnet_matches, key=lambda m: m.distance)
    if len(matches) < 4:
        raise ValueError(
            f"Need at least 4 matches to estimate a homography, found {len(matches)}"
        )

    # Adaptive threshold in descriptor-distance units (used to filter matches,
    # not as a pixel reprojection threshold).
    median_distance = np.median([m.distance for m in matches])
    adaptive_threshold = median_distance * threshold_multiplier

    # One row per match: [src_x, src_y, dst_x, dst_y].
    src_pts = np.float32([keypoints1[m.queryIdx].pt for m in matches])
    dst_pts = np.float32([keypoints2[m.trainIdx].pt for m in matches])
    data = np.hstack([src_pts, dst_pts])

    # Use a separate name so the feature-extraction `model` argument is not
    # shadowed by the fitted homography model.
    homography_model, _ = mlesac(data, HomographyModel, min_samples=4, max_iterations=1000, sigma=1.0, threshold=threshold)
    if homography_model is None or homography_model.H is None:
        raise RuntimeError("MLESAC could not estimate a homography from the matches")
    H = homography_model.H

    print(f"Number of key points found: {len(matches)}")
    print(f"Adaptive threshold used: {adaptive_threshold}")

    return H, matches, adaptive_threshold, keypoints1, keypoints2


if __name__ == '__main__':
    # Script entry point: register every moving image against the fixed image,
    # save the warped result and a visualisation of the matches.

    # Define the paths to the fixed image and the moving images
    fixed_image_path = 'replace with your fixed image path'
    moving_image_paths = ['replace with your moving image(s) path']

    # Load and preprocess the images
    fixed_image, moving_images = open_images(fixed_image_path, moving_image_paths)

    # Output size of the warped images and the adaptive-threshold multiplier
    height, width = fixed_image.shape[:2]
    threshold_multiplier = 1

    # Load the pre-trained ResNet50 model, and remove its last layer to get the features
    model = models.resnet50(pretrained=True)
    model = torch.nn.Sequential(*(list(model.children())[:-1]))
    # Inference only: put the weights on the same device as the inputs and
    # switch layers such as batch normalisation to evaluation mode.
    model = model.to(device).eval()

    # Iterate over each moving image and register it with the fixed image
    for idx, img2 in enumerate(moving_images):

        # perf_counter is a monotonic clock intended for measuring durations
        start_time = time.perf_counter()

        # Estimate the homography; also returns the matches, adaptive threshold and keypoints
        H, matches, adaptive_threshold, keypoints1, keypoints2 = homography_registration(fixed_image, img2, model, threshold_multiplier=threshold_multiplier)

        # H is estimated from fixed-image points to moving-image points, i.e. it maps
        # output pixels to source pixels, hence WARP_INVERSE_MAP.
        result = cv2.warpPerspective(img2, H, (width, height), flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP)

        # Stop timer and calculate registration time
        end_time = time.perf_counter()

        # Count the matches whose descriptor distance is within the adaptive threshold
        inliers = sum(1 for m in matches if m.distance <= adaptive_threshold)
        Registration_time = end_time - start_time

        # Print the results
        print("Registration_time (s): ", Registration_time)
        print(f"Image {idx + 1}: Number of inliers with adaptive threshold: {inliers}")

        # Save the result image (placeholder path: every image is written to the same file)
        cv2.imwrite('replace with your result image path', result)

        # Count the matches whose descriptor distance is within a fixed threshold
        fixed_threshold = 20
        inliers = sum(1 for m in matches if m.distance <= fixed_threshold)
        print(f"Image {idx + 1}: Number of inliers without adaptive threshold: {inliers}")

        # Save an image with the matched keypoints
        draw_keypoint_matches(fixed_image, img2, keypoints1, keypoints2, matches, adaptive_threshold, 'replace with your image path')

