In [None]:
from registration_by_features import RegistrationByFeatures
from registration_by_intensity import RegistrationByIntensity

import matplotlib.pyplot as plt
import sys
import cv2

In [None]:
def rgb2gray(rgb):
    """Collapse a 3-channel image into a single-channel weighted sum.

    Channels 0, 1 and 2 along the last axis are weighted by 0.2989, 0.5870
    and 0.1140 respectively and added together.

    Args:
        rgb: array indexable as ``rgb[:, :, c]`` for ``c`` in 0..2
            (presumably in R, G, B order, as the name suggests).

    Returns:
        The weighted sum of the three channel planes.
    """
    red, green, blue = (rgb[:, :, channel] for channel in range(3))
    return 0.2989 * red + 0.5870 * green + 0.1140 * blue

In [None]:
def display_result(img_BL, tf_img):
    """Show the registration outcome in three consecutive figures.

    The first figure draws the baseline image in gray with the transformed
    follow-up image on top at 40% opacity; the second and third figures show
    each image on its own in gray.

    Args:
        img_BL: baseline image.
        tf_img: follow-up image after registration.
    """

    def _start_figure(title):
        # Every figure shares the same size and title font.
        plt.figure(figsize=(10, 10))
        plt.title(title, fontsize=25)

    # Overlay of both images.
    _start_figure('BL + transformed FU')
    plt.imshow(img_BL, "gray")
    plt.imshow(tf_img, alpha=0.4)
    plt.show()

    # Each image separately.
    for title, image in (('BL', img_BL), ('transformed FU', tf_img)):
        _start_figure(title)
        plt.imshow(image, "gray")
        plt.show()



In [None]:
if __name__ == '__main__':
    # Driver: load one baseline (BL) / follow-up (FU) pair as grayscale and
    # run both registration strategies on it, displaying each result.

    img_BL_path = "../retinal_images/BL022.bmp"
    img_FU_path = "../retinal_images/FU022.bmp"

    img_BL = cv2.imread(img_BL_path, cv2.IMREAD_GRAYSCALE)
    img_FU = cv2.imread(img_FU_path, cv2.IMREAD_GRAYSCALE)

    # cv2.imread signals a missing or unreadable file by returning None
    # instead of raising; fail here with a clear message rather than deep
    # inside the registration code.
    for path, image in ((img_BL_path, img_BL), (img_FU_path, img_FU)):
        if image is None:
            raise FileNotFoundError("could not read image: {}".format(path))

    # Feature-based registration
    regByFeatures = RegistrationByFeatures()
    transformed_FU = regByFeatures.doRegistration(img_BL, img_FU)
    display_result(img_BL, transformed_FU)

    # Intensity-based registration
    regByIntensity = RegistrationByIntensity()
    transformed_FU = regByIntensity.doRegistration(img_BL, img_FU)
    display_result(img_BL, transformed_FU)


In [None]:
import numpy as np
from skimage.filters import threshold_isodata
from skimage import morphology
import cv2
from scipy import ndimage
from skimage.registration import phase_cross_correlation

class RegistrationByIntensity:
    """Align a follow-up (FU) image to a baseline (BL) image via segmentation masks.

    Both images are reduced to boolean masks by ``SegmentBloodVessel``; a
    rotation angle is found by brute force and, for each angle, the
    translation is estimated with phase cross-correlation.
    """

    def SegmentBloodVessel(self, img):
        """Return a boolean segmentation mask of ``img``.

        Pipeline: CLAHE contrast enhancement, subtraction of a heavily
        blurred copy, light smoothing, ISODATA thresholding, three
        morphological openings (3x3, 5x5, 7x7), one closing (17x17) and
        removal of connected components smaller than 200 pixels.

        Args:
            img: single-channel image accepted by ``cv2.CLAHE.apply``
                (assumes 8-bit grayscale — TODO confirm).

        Returns:
            Boolean array with the same shape as ``img``.
        """
        # Contrast enhancement
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        cl_img = clahe.apply(img)

        # Subtract the large-scale background estimate.
        # NOTE(review): if cl_img and blur are unsigned 8-bit arrays, the
        # subtraction wraps modulo 256 before np.clip sees it, so the clip is
        # a no-op and pixels darker than their surroundings become large
        # values. The threshold and morphology below were presumably tuned
        # with this behaviour, so it is deliberately left unchanged — confirm.
        blur = cv2.GaussianBlur(cl_img, (57, 57), 0)
        front_img = np.clip(cl_img - blur, 0, 255)
        front_img = cv2.GaussianBlur(front_img, (3, 3), 0)
        thresh = threshold_isodata(front_img)

        # Binary image: 1 where the filtered image reaches the threshold.
        segmentation = np.zeros(img.shape)
        segmentation[front_img >= thresh] = 1

        # Successive openings with growing square kernels.
        opened = segmentation
        for size in (3, 5, 7):
            kernel = np.ones((size, size), np.uint8)
            opened = cv2.morphologyEx(opened, cv2.MORPH_OPEN, kernel)

        # Closing
        kernel = np.ones((17, 17), np.uint8)
        closing = cv2.morphologyEx(opened, cv2.MORPH_CLOSE, kernel)

        # Remove small leftover objects
        cleaned = morphology.remove_small_objects(closing.astype(bool), min_size=200, connectivity=2)

        return cleaned

    def doRegistration(self, img_BL, img_FU):
        """Register ``img_FU`` onto ``img_BL`` with a rotation plus translation.

        Every integer angle in [-30, 30) degrees is tried: the FU mask is
        rotated (about the image centre, shape preserved) and the shift
        against the BL mask is estimated with ``phase_cross_correlation``.
        The angle with the smallest reported error wins (first one on ties).

        Args:
            img_BL: baseline image.
            img_FU: follow-up image, same size as ``img_BL`` (assumed).

        Returns:
            ``img_FU`` rotated by the chosen angle and then translated by the
            matching shift, with the same shape as ``img_FU``.
        """
        seg_BL = self.SegmentBloodVessel(img_BL)
        seg_FU = self.SegmentBloodVessel(img_FU)

        candidates = []
        for angle in np.arange(-30, 30, 1):
            # Rotate the FU mask, then measure the residual translation.
            rotated_FU = ndimage.rotate(seg_FU, angle, reshape=False)
            shift, error, diffphase = phase_cross_correlation(seg_BL, rotated_FU)
            candidates.append((shift, error, diffphase, angle))

        # Pick the candidate with the smallest error.
        best_shift, _, _, optimal_angle = min(candidates, key=lambda c: c[1])
        # The shift follows array axis order (row, col); warpAffine wants (x, y).
        optimal_translation = np.flip(best_shift)
        print("optimal angle:", optimal_angle)

        # The shift was estimated for the *rotated* FU mask, so the image has
        # to be rotated first and translated afterwards; doing it the other
        # way round would rotate the translation vector as well.
        rotated_FU_img = ndimage.rotate(img_FU, optimal_angle, reshape=False)
        T_mat = np.float32([[1, 0, optimal_translation[0]], [0, 1, optimal_translation[1]]])
        height, width = img_FU.shape[:2]
        registered = cv2.warpAffine(rotated_FU_img, T_mat, (width, height))

        return registered

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from skimage import transform
import cv2
from cv2 import xfeatures2d

class RegistrationByFeatures:
    """Register a follow-up (FU) image onto a baseline (BL) image from matched keypoints.

    Keypoints are detected with SIFT, matched with a brute-force matcher, and
    a rigid transform (rotation + translation) is fitted to the matched
    point pairs, optionally inside a RANSAC loop.

    Registration matrices use the row-vector convention: a homogeneous FU
    point ``[x, y, 1]`` is mapped to BL coordinates by ``point @ matrix``.
    """

    def findFeatures(self, img):
        """Detect SIFT keypoints in ``img`` and compute their descriptors.

        Returns:
            ``(keypoints, descriptors)`` as returned by ``detectAndCompute``.
        """
        # Prefer SIFT from the main cv2 module when it exists; fall back to
        # the contrib location used originally.
        create_sift = getattr(cv2, "SIFT_create", None)
        if create_sift is None:
            create_sift = cv2.xfeatures2d.SIFT_create
        sift = create_sift()
        keypoints, descriptors = sift.detectAndCompute(img, None)

        return keypoints, descriptors

    def calcPointBasedReg(self, FUPoints, BLPoints):
        """Fit the rigid transform that maps ``FUPoints`` onto ``BLPoints``.

        The rotation is obtained from the SVD of the cross-covariance of the
        centred point sets, with the last singular direction sign-corrected by
        ``det(V U^T)``.

        Args:
            FUPoints: (N, d) array of follow-up points.
            BLPoints: (N, d) array of the corresponding baseline points.

        Returns:
            (d+1, d+1) matrix ``[[R^T, 0], [t, 1]]`` such that
            ``[FU, 1] @ matrix`` approximates ``[BL, 1]``.
        """
        # Centroids of both point sets
        centroid_BL = BLPoints.mean(axis=0)
        centroid_FU = FUPoints.mean(axis=0)

        # Centred vectors
        centered_BLPoints = BLPoints - centroid_BL
        centered_FUPoints = FUPoints - centroid_FU

        # d x d covariance matrix and its SVD
        S = centered_FUPoints.T @ centered_BLPoints
        U, sigma, Vt = np.linalg.svd(S)

        # Rotation R = V G U^T, where G is the identity except for its last
        # diagonal entry, which is det(V U^T).
        V = Vt.T
        Ut = U.T
        G = np.eye(V.shape[1])
        G[-1, -1] = np.linalg.det(V @ Ut)
        R = V @ G @ Ut

        # Translation vector
        t = centroid_BL - (R @ centroid_FU)

        # Assemble the homogeneous matrix in row-vector form.
        dim = R.shape[0]
        rigidReg = np.eye(dim + 1)
        rigidReg[:dim, :dim] = R.T
        rigidReg[dim, :dim] = t

        return rigidReg

    def ransac(self, x, y, funcFindF, funcDist, minPtNum, iterNum, thDist, thInlrRatio):
        """Robustly fit a 3x3 model to corresponding points ``x`` -> ``y``.

        Each iteration fits a model to ``minPtNum`` random pairs, counts the
        pairs whose distance under that model is at most ``thDist`` and, if
        the count reaches ``round(thInlrRatio * len(x))``, refits the model on
        those inliers. The refit of the iteration with the most inliers is
        returned (the first such iteration on ties).

        Args:
            x, y: (N, 2) arrays of corresponding points.
            funcFindF: callable ``(x_subset, y_subset) -> 3x3 model``.
            funcDist: callable ``(x, y, model) -> per-pair distances``.
            minPtNum: number of pairs sampled per iteration.
            iterNum: number of iterations.
            thDist: inlier distance threshold.
            thInlrRatio: minimum inlier fraction for a model to be kept.

        Returns:
            ``(f, inlierIdx)``: the chosen 3x3 model and the indices of the
            pairs within ``thDist`` under it. If no iteration reached the
            inlier threshold, ``f`` is the all-zero matrix.
        """
        ptNum = len(x)
        thInlr = round(thInlrRatio * ptNum)

        inlrNum = np.zeros(iterNum)
        fLib = np.zeros(shape=(iterNum, 3, 3))

        for i in range(iterNum):
            sampleIdx = np.random.permutation(ptNum)[:minPtNum]
            candidate = funcFindF(x[sampleIdx, :], y[sampleIdx, :])
            inliers = np.flatnonzero(funcDist(x, y, candidate) <= thDist)
            inlrNum[i] = inliers.size

            if inliers.size < thInlr:
                continue

            # Refit on every inlier of the candidate model.
            fLib[i] = funcFindF(x[inliers, :], y[inliers, :])

        # argmax returns the first iteration that reached the maximum count.
        f = fLib[int(np.argmax(inlrNum))]
        inlierIdx = np.flatnonzero(funcDist(x, y, f) <= thDist)

        return f, inlierIdx

    def cartesianToHomogeneous(self, Points):
        """Append a column of ones to an (N, d) point array."""
        points_num = Points.shape[0]
        return np.concatenate((Points, np.ones((points_num, 1))), axis=1)

    def calcDist(self, FUPoints, BLPoints, registration_matrix):
        """Return the Euclidean distance of each transformed FU point to its BL point.

        Both point sets are lifted to homogeneous coordinates and the FU
        points are mapped with ``points @ registration_matrix``.
        """
        homogeneous_BLPoints = self.cartesianToHomogeneous(BLPoints)
        homogeneous_FUPoints = self.cartesianToHomogeneous(FUPoints)

        new_FUPoints = homogeneous_FUPoints @ registration_matrix

        distances = np.sqrt(((homogeneous_BLPoints - new_FUPoints) ** 2).sum(axis=1))

        return distances

    def calcRobustPointBasedReg(self, FUPoints, BLPoints):
        """Fit the rigid FU -> BL transform with RANSAC.

        Uses 4 sampled pairs per iteration, 200 iterations, a distance
        threshold of 20 and a minimum inlier ratio of 0.1.

        Returns:
            ``(matrix, inlierIdx)`` as returned by ``ransac``.
        """
        f, inlierIdx = self.ransac(x=FUPoints, y=BLPoints, funcFindF=self.calcPointBasedReg, funcDist=self.calcDist,
                                    minPtNum=4, iterNum=200, thDist=20, thInlrRatio=0.1)
        return f, inlierIdx

    def applyTransformation(self, img_FU, rigidReg):
        """Warp ``img_FU`` into the baseline frame using ``rigidReg``.

        ``rigidReg`` is in row-vector form, so it is transposed before being
        handed to ``AffineTransform``; ``warp`` is given the inverse of that
        transform as its coordinate map.
        """
        tform = transform.AffineTransform(rigidReg.T)
        tf_img = transform.warp(img_FU, tform.inverse)

        return tf_img

    def plot_images(self, img_BL, img_FU, BLPoints, FUPoints, inliersIdx=None):
        """Show both images side by side with their matched points numbered.

        Args:
            img_BL, img_FU: images to display in gray.
            BLPoints, FUPoints: (N, 2) arrays of corresponding points.
            inliersIdx: optional indices of inlier pairs; when given, inliers
                are drawn red and all remaining pairs blue, otherwise every
                point is drawn red.
        """
        f, axarr = plt.subplots(1, 2)

        # show images
        axarr[0].imshow(img_BL, "gray")
        axarr[1].imshow(img_FU, "gray")

        # set titles to subplots
        axarr[0].set_title('BaseLine')
        axarr[1].set_title('Follow Up')

        points_num = BLPoints.shape[0]

        # plot points
        if inliersIdx is not None:
            # Every index that is not an inlier is an outlier, index 0 included.
            outliersIdx = np.setdiff1d(np.arange(points_num), inliersIdx)
            BL_inliers = BLPoints[inliersIdx]
            BL_outliers = BLPoints[outliersIdx]
            FU_inliers = FUPoints[inliersIdx]
            FU_outliers = FUPoints[outliersIdx]

            l1 = axarr[0].scatter(BL_inliers[:, 0], BL_inliers[:, 1], color="red")
            axarr[1].scatter(FU_inliers[:, 0], FU_inliers[:, 1], color="red")
            l2 = axarr[0].scatter(BL_outliers[:, 0], BL_outliers[:, 1], color="blue")
            axarr[1].scatter(FU_outliers[:, 0], FU_outliers[:, 1], color="blue")

            plt.legend([l1, l2], ["inliers", "outliers"], bbox_to_anchor=(1.05, 1), borderaxespad=0.)

        else:
            axarr[0].scatter(BLPoints[:, 0], BLPoints[:, 1], color="red")
            axarr[1].scatter(FUPoints[:, 0], FUPoints[:, 1], color="red")

        # label each pair with its 1-based number
        for i in range(points_num):
            axarr[0].annotate(i + 1, BLPoints[i], color="orange")
            axarr[1].annotate(i + 1, FUPoints[i], color="orange")

        plt.show()

    def featuresMatching(self, kp1, des1, kp2, des2, img1, img2):
        """Match descriptors one-to-one (cross-checked) and show the 20 best matches.

        Returns:
            ``(BLPoints, FUPoints, matches)`` where the point arrays hold the
            keypoint coordinates of each match (from ``kp1`` and ``kp2``
            respectively) and ``matches`` is sorted by ascending distance.
        """
        # The descriptors produced by findFeatures are SIFT descriptors, which
        # are compared with the L2 norm; the Hamming norm is meant for binary
        # descriptors.
        bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
        matches = bf.match(des1, des2)
        matches = sorted(matches, key=lambda x: x.distance)

        BLPoints = np.array([list(kp1[m.queryIdx].pt) for m in matches])
        FUPoints = np.array([list(kp2[m.trainIdx].pt) for m in matches])

        # Display Matches
        draw_params = dict(matchColor=(0, 255, 0), singlePointColor=(255, 0, 0), flags=0)
        img3 = cv2.drawMatches(img1, kp1, img2, kp2, matches[:20], outImg=None, **draw_params)
        plt.imshow(img3)
        plt.show()

        return BLPoints, FUPoints, matches

    def featuresMatchingKNN(self, kp1, des1, kp2, des2, img1, img2):
        """Match descriptors with 2-nearest-neighbour search and a 0.75 ratio test.

        A match is kept when its distance is below 0.75 times the distance of
        the second-best candidate. The kept matches are displayed.

        Returns:
            ``(BLPoints, FUPoints, good_matches)`` where the point arrays hold
            the keypoint coordinates of each kept match and ``good_matches``
            is a list of one-element lists, as ``drawMatchesKnn`` expects.
        """
        bf = cv2.BFMatcher()
        matches = bf.knnMatch(des1, des2, k=2)

        good_matches = []
        for pair in matches:
            # The ratio test needs both neighbours; skip entries that came
            # back with fewer than two.
            if len(pair) < 2:
                continue
            m, n = pair
            if m.distance < 0.75 * n.distance:
                good_matches.append([m])

        draw_params = dict(matchColor=(0, 255, 0), singlePointColor=(255, 0, 0), flags=0)
        img3 = cv2.drawMatchesKnn(img1, kp1, img2, kp2, good_matches, outImg=None, **draw_params)
        plt.imshow(img3)
        plt.show()

        BLPoints = np.array([list(kp1[m[0].queryIdx].pt) for m in good_matches])
        FUPoints = np.array([list(kp2[m[0].trainIdx].pt) for m in good_matches])

        return BLPoints, FUPoints, good_matches

    def doRegistration(self, img_BL, img_FU, ROBUST=True):
        """Register ``img_FU`` onto ``img_BL`` using matched SIFT features.

        Args:
            img_BL: baseline image.
            img_FU: follow-up image.
            ROBUST: when true, fit the transform with RANSAC and plot the
                inlier/outlier pairs; otherwise fit it to all matches.

        Returns:
            The follow-up image warped into the baseline frame.
        """
        # Features detecting
        kp1, des1 = self.findFeatures(img_BL)
        kp2, des2 = self.findFeatures(img_FU)

        # Features matching
        BLPoints, FUPoints, matches = self.featuresMatchingKNN(kp1, des1, kp2, des2, img_BL, img_FU)

        # Calculate registration matrix
        if ROBUST:
            registration_matrix, inlier_indx = self.calcRobustPointBasedReg(FUPoints, BLPoints)
            # Plot on the images passed in; the instance never stores them.
            self.plot_images(img_BL, img_FU, BLPoints, FUPoints, inlier_indx)
        else:
            registration_matrix = self.calcPointBasedReg(FUPoints, BLPoints)

        # Registration
        tf_img = self.applyTransformation(img_FU, registration_matrix)

        return tf_img