In [1]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import random
import math
import sys

def read_img(path):
    """Read an image file and return it in colour and in grayscale.

    Parameters
    ----------
    path : str
        Path of the image file handed to ``cv2.imread``.

    Returns
    -------
    tuple
        ``(img, img_gray)`` where ``img`` is the image as loaded by OpenCV
        (BGR channel order) and ``img_gray`` is its ``COLOR_BGR2GRAY``
        conversion.

    Raises
    ------
    FileNotFoundError
        If OpenCV could not load the file (missing path, unreadable or
        unsupported image data).
    """
    # OpenCV reads images in BGR colour space and signals failure by
    # returning None instead of raising, so check explicitly; otherwise the
    # failure only surfaces later as an opaque error inside cvtColor.
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"cv2.imread could not read an image from {path!r}")
    img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    return img, img_gray

def img_to_gray(img):
    """Convert a BGR image to grayscale, accepting only ``uint8`` input.

    The dtype of ``img`` must be "uint8" to avoid errors in the SIFT detector.
    For any other dtype a message is printed and ``None`` is returned.
    """
    if img.dtype == "uint8":
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Wrong dtype: report it and fall through to an implicit None result.
    print("The input image dtype is not uint8 , image type is : ",img.dtype)
    return None

def creat_im_window(window_name, img):
    """Queue ``img`` for display in an OpenCV window titled ``window_name``.

    Only ``cv2.imshow`` is called here; call ``im_show()`` at the end of the
    main flow so that every window created so far is actually serviced.
    """
    cv2.imshow(window_name, img)

def im_show():
    """Block until any key is pressed, then close every OpenCV window.

    Call this after one or more ``creat_im_window()`` calls.
    """
    cv2.waitKey(0)            # 0 = wait for a key press with no timeout
    cv2.destroyAllWindows()

#if __name__ == '__main__':
    # the example of image window
    # creat_im_window("Result",img)
    # im_show()

    # you can use this function to store the result
    # cv2.imwrite("result.jpg",img)

In [2]:
def SIFT_get_keypoint_and_descriptors(img):
    """Run a freshly created SIFT detector on ``img``.

    Returns the ``(keypoints, descriptors)`` pair produced by
    ``detectAndCompute`` with no mask.
    """
    detector = cv2.SIFT_create()
    return detector.detectAndCompute(img, None)

In [3]:
def solve_homography(p1, p2):
    """Estimate the 3x3 homography that maps ``p1`` onto ``p2``.

    Builds the direct-linear-transform system (two rows per correspondence)
    and takes the right singular vector belonging to the smallest singular
    value as the flattened homography.

    Parameters
    ----------
    p1 : array-like, shape (N, >=2)
        Origin coordinates; columns 0 and 1 are read as x and y.
    p2 : array-like, shape (N, >=2)
        Target coordinates, row-aligned with ``p1``.

    Returns
    -------
    numpy.ndarray, shape (3, 3)
        Homography scaled so that its bottom-right entry is 1. If that entry
        is exactly 0 (degenerate configuration) the division yields inf/nan
        values rather than raising, so callers such as a RANSAC loop can
        simply reject the hypothesis.

    Raises
    ------
    ValueError
        If the inputs are not 2-D with at least two columns, have different
        numbers of points, or contain fewer than four correspondences.
    """
    src = np.asarray(p1, dtype=np.float64)
    dst = np.asarray(p2, dtype=np.float64)

    if src.ndim != 2 or dst.ndim != 2 or src.shape[1] < 2 or dst.shape[1] < 2:
        raise ValueError("p1 and p2 must be 2-D arrays with at least two columns (x, y)")
    if src.shape[0] != dst.shape[0]:
        raise ValueError(
            f"p1 and p2 must contain the same number of points, got {src.shape[0]} and {dst.shape[0]}"
        )
    num_points = src.shape[0]
    if num_points < 4:
        raise ValueError(f"at least 4 point correspondences are required, got {num_points}")

    x, y = src[:, 0], src[:, 1]
    u, v = dst[:, 0], dst[:, 1]

    # Even rows hold the x-equation, odd rows the y-equation of each pair.
    A = np.zeros((2 * num_points, 9), dtype=np.float64)
    A[0::2, 0] = -x
    A[0::2, 1] = -y
    A[0::2, 2] = -1.0
    A[0::2, 6] = x * u
    A[0::2, 7] = y * u
    A[0::2, 8] = u
    A[1::2, 3] = -x
    A[1::2, 4] = -y
    A[1::2, 5] = -1.0
    A[1::2, 6] = x * v
    A[1::2, 7] = y * v
    A[1::2, 8] = v

    # Singular values come back in descending order, so the last row of vh
    # is the (approximate) null-space vector of A.
    _, _, vh = np.linalg.svd(A)
    H = vh[-1].reshape(3, 3)

    # Normalise so that h33 == 1.
    return H / H[2, 2]

In [4]:
class image_stitching:
    """Stitch a left/right pair of overlapping images into one panorama.

    Steps, each exposed as a method: load images ``m1.jpg`` ... from a folder,
    detect SIFT keypoints, match descriptors with a nearest-neighbour search
    plus Lowe's ratio test, estimate a homography with RANSAC, then warp the
    right image and blend it with the left one using linear ramp masks.
    """

    def __init__(self, path, num_image):
        """Store the configuration; no files are read here.

        Parameters
        ----------
        path : str
            Folder that contains ``m1.jpg`` ... ``m<num_image>.jpg``.
        num_image : int
            Number of images to load from ``path``.
        """
        self.path = path
        self.num_image = num_image
        self.image = []                     # colour images, filled by get_image()
        self.gray_image = []                # grayscale versions, same order
        self.kps_desc = []                  # [keypoints, descriptors] per image
        self.ratio = 0.75                   # Lowe's ratio-test factor
        self.goodmatch_pos = None           # last result of feature_matching()
        self.H = None                       # last result of compute_homography_matrix()
        self.final_result = None            # not written by any method of this class
        self.smoothing_window_size = 800    # width (px) of the blending ramp

    def get_image(self):
        """Load ``m1.jpg`` ... ``m<num_image>.jpg`` from ``self.path``.

        The image lists are rebuilt from scratch on every call so that calling
        this method more than once does not accumulate duplicates.
        """
        self.image = []
        self.gray_image = []
        for i in range(1, self.num_image + 1):
            image_path = f"{self.path}/m{i}.jpg"
            print(image_path)
            image, gray_image = read_img(image_path)
            self.image.append(image)
            self.gray_image.append(gray_image)

    def detecting_key_point(self):
        """Compute SIFT keypoints and descriptors for every loaded image.

        ``self.kps_desc`` is rebuilt on every call, so repeated calls stay
        aligned index-for-index with ``self.image``.
        """
        self.kps_desc = []
        for i in range(self.num_image):
            keypoints, descriptors = SIFT_get_keypoint_and_descriptors(self.image[i])
            self.kps_desc.append([keypoints, descriptors])

    def feature_matching(self, kps_desc_left, kps_desc_right):
        """Match left descriptors to right descriptors (2-NN + ratio test).

        For each left descriptor the Euclidean distance to every right
        descriptor is computed. The closest right descriptor is accepted when
        its distance is at most ``self.ratio`` times the next larger distinct
        distance (treated as infinity when there is none).

        Parameters
        ----------
        kps_desc_left, kps_desc_right : sequence
            ``[keypoints, descriptors]`` for each image; keypoints must expose
            a ``.pt`` attribute holding ``(x, y)``.

        Returns
        -------
        list
            ``[(x_left, y_left), (x_right, y_right)]`` entries with coordinates
            truncated to ``int``. Also stored in ``self.goodmatch_pos``.
            Empty when either image has no descriptors.
        """
        keypoints_left, descriptors_left = kps_desc_left[0], kps_desc_left[1]
        keypoints_right, descriptors_right = kps_desc_right[0], kps_desc_right[1]

        goodmatch_pos = []
        has_left = descriptors_left is not None and len(descriptors_left) > 0
        has_right = descriptors_right is not None and len(descriptors_right) > 0

        if has_left and has_right:
            descriptors_left = np.asarray(descriptors_left)
            descriptors_right = np.asarray(descriptors_right)

            for left_idx, descriptor in enumerate(descriptors_left):
                # Distances from this left descriptor to all right descriptors.
                distances = np.linalg.norm(descriptors_right - descriptor, axis=1)

                right_idx = int(np.argmin(distances))   # first occurrence of the minimum
                nearest = distances[right_idx]
                farther = distances[distances > nearest]
                second_nearest = farther.min() if farther.size else np.inf

                # Lowe's ratio test
                if nearest <= second_nearest * self.ratio:
                    pt_left = keypoints_left[left_idx].pt
                    pt_right = keypoints_right[right_idx].pt
                    goodmatch_pos.append([
                        (int(pt_left[0]), int(pt_left[1])),
                        (int(pt_right[0]), int(pt_right[1])),
                    ])

        self.goodmatch_pos = goodmatch_pos
        return goodmatch_pos

    def compute_homography_matrix(self, matches_pos, threshold=5, num_iter=10000):
        """Estimate the right-to-left homography with RANSAC.

        Parameters
        ----------
        matches_pos : list
            ``[destination_point, source_point]`` pairs as produced by
            ``feature_matching`` (left point first, right point second).
        threshold : float, optional
            Maximum reprojection distance in pixels for a match to count as an
            inlier.
        num_iter : int, optional
            Number of random 4-point hypotheses to evaluate.

        Returns
        -------
        numpy.ndarray or None
            The hypothesis with the most inliers, also stored in ``self.H``.
            ``None`` only if every sampled hypothesis was degenerate
            (contained non-finite values).

        Raises
        ------
        ValueError
            If fewer than four matches are supplied.
        """
        num_subsample = 4
        num_sample = len(matches_pos)
        if num_sample < num_subsample:
            raise ValueError(
                f"at least {num_subsample} matches are required to estimate a homography, got {num_sample}"
            )

        destination_points = np.array(
            [list(destination_point) for destination_point, _ in matches_pos], dtype=np.float64
        )
        source_points = np.array(
            [list(source_point) for _, source_point in matches_pos], dtype=np.float64
        )
        # Homogeneous source coordinates, built once for all hypotheses.
        source_homogeneous = np.hstack((source_points, np.ones((num_sample, 1))))

        # Start below zero so that a valid hypothesis is kept even when no
        # match outside the sampled four agrees with it (e.g. exactly 4 matches).
        max_inlier = -1
        BEST_H = None

        for _ in range(num_iter):
            subsample_index = random.sample(range(num_sample), num_subsample)
            H = solve_homography(source_points[subsample_index], destination_points[subsample_index])
            if not np.all(np.isfinite(H)):
                continue  # degenerate sample, cannot be evaluated

            projected = source_homogeneous @ H.T
            w = projected[:, 2]

            # Skip points with a non-positive homogeneous scale, and do not
            # count the four points the hypothesis was fitted to.
            candidates = w > 1e-8
            candidates[subsample_index] = False

            reprojected = projected[candidates, :2] / w[candidates, None]
            errors = np.linalg.norm(reprojected - destination_points[candidates], axis=1)
            num_inlier = int(np.count_nonzero(errors < threshold))

            if max_inlier < num_inlier:
                max_inlier = num_inlier
                BEST_H = H

        self.H = BEST_H
        return BEST_H

    def create_mask(self, image_left, image_right, version):
        """Build a 3-channel blending mask for the panorama canvas.

        The canvas is ``max(height)`` by ``width_left + width_right``. A linear
        ramp ends at the right edge of the left image; it is
        ``self.smoothing_window_size`` columns wide, shrunk to fit when the
        left image is narrower than that.

        Parameters
        ----------
        image_left, image_right : numpy.ndarray
            Only their height and width are used.
        version : str
            ``'image_left'`` gives 1 left of the ramp and a 1 -> 0 ramp; any
            other value gives a 0 -> 1 ramp and 1 right of it.

        Returns
        -------
        numpy.ndarray
            Float array of shape ``(height, width, 3)``.
        """
        (height_left, width_left) = image_left.shape[:2]
        (height_right, width_right) = image_right.shape[:2]

        height_panorama = max(height_left, height_right)
        width_panorama = width_left + width_right

        # Clamp the half-window so the ramp never starts before column 0;
        # a negative slice start would otherwise break the assignment below.
        offset = min(int(self.smoothing_window_size / 2), width_left // 2)
        barrier = width_left - offset
        mask = np.zeros((height_panorama, width_panorama))

        if version == 'image_left':
            mask[:, barrier - offset:barrier + offset] = np.tile(np.linspace(1, 0, 2 * offset), (height_panorama, 1))
            mask[:, :barrier - offset] = 1
        else:
            mask[:, barrier - offset:barrier + offset] = np.tile(np.linspace(0, 1, 2 * offset), (height_panorama, 1))
            mask[:, barrier + offset:] = 1
        return np.dstack((mask, mask, mask))

    def stitching_image(self, images):
        """Warp the right image with ``self.H`` and blend it with the left one.

        Parameters
        ----------
        images : sequence
            ``[image_left, image_right]``.

        Returns
        -------
        numpy.ndarray
            Float panorama cropped to the bounding box of its non-zero pixels.

        Raises
        ------
        ValueError
            If no homography is available yet, or the blended canvas is
            entirely zero so no bounding box exists.
        """
        if self.H is None:
            raise ValueError("no homography available; call compute_homography_matrix() first")

        image_left, image_right = images[0], images[1]
        (height_left, width_left) = image_left.shape[:2]
        (height_right, width_right) = image_right.shape[:2]

        height_panorama = max(height_left, height_right)
        width_panorama = width_left + width_right

        # Left image on a float canvas, weighted by its mask.
        panoramal1 = np.zeros((height_panorama, width_panorama, 3))
        mask_left = self.create_mask(image_left, image_right, version='image_left')
        panoramal1[0:height_left, 0:width_left, :] = image_left
        panoramal1 = np.multiply(panoramal1, mask_left, out=panoramal1, casting='unsafe')

        # Right image warped into the left image's frame; the weighted product
        # is written back into the warped array's own dtype.
        mask_right = self.create_mask(image_left, image_right, version='image_right')
        panoramal2 = cv2.warpPerspective(image_right, self.H, (width_panorama, height_panorama))
        panoramal2 = np.multiply(panoramal2, mask_right, out=panoramal2, casting='unsafe')

        result = panoramal1 + panoramal2

        # Crop to the pixels where any channel is non-zero.
        occupied = np.any(result != 0, axis=2)
        rows = np.flatnonzero(occupied.any(axis=1))
        cols = np.flatnonzero(occupied.any(axis=0))
        if rows.size == 0:
            raise ValueError("stitched panorama is empty (all pixels are zero)")

        final_result = result[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1, :]
        return final_result

    def show_result(self, image):
        """Show ``image`` in a window named "result" and wait for a key press."""
        creat_im_window(f"result",image)
        im_show()

    def draw_key_points(self):
        """Draw each image's keypoints on its grayscale version, show and save it.

        Files are written as ``./key_points_<i>.jpg``. ``None`` is passed as
        the output image so OpenCV allocates a new array instead of drawing
        into ``self.image[i]``, which later steps still need untouched.
        """
        for i in range(self.num_image):
            image = cv2.drawKeypoints(self.gray_image[i], self.kps_desc[i][0], None)
            creat_im_window(f"image {i}", image)
            cv2.imwrite(f"./key_points_{i}.jpg", image)

        im_show()

    def drawmatch(self, images, matches_pos):
        """Draw the matches side by side, save to ``./match_result.jpg`` and show.

        Parameters
        ----------
        images : sequence
            ``[image_left, image_right]``.
        matches_pos : list
            ``[(x_left, y_left), (x_right, y_right)]`` integer pairs as
            returned by ``feature_matching``.
        """
        image_left, image_right = images[0], images[1]
        (height_left, width_left) = image_left.shape[:2]
        (height_right, width_right) = image_right.shape[:2]

        photo = np.zeros((max(height_left, height_right), width_left + width_right, 3), dtype="uint8")
        photo[0:height_left, 0:width_left] = image_left
        photo[0:height_right, width_left:] = image_right

        for (left_image_pos, right_image_pos) in matches_pos:
            pos_left = left_image_pos
            # The right image sits to the right of the left one on the canvas.
            pos_right = right_image_pos[0] + width_left, right_image_pos[1]

            cv2.circle(photo, pos_left, 1, (0, 0, 255))
            cv2.circle(photo, pos_right, 1, (0, 255, 0))
            cv2.line(photo, pos_left, pos_right, (255, 0, 0), 1)

        cv2.imwrite("./match_result.jpg", photo)
        creat_im_window("match_result", photo)
        im_show()

    def stiching_management(self):
        """Stitch images 0 and 1 and write the result to ``./test1/m1.jpg``."""
        self.stiching_management_1("./test1")

    def stiching_management_1(self, path):
        """Run the full pipeline on images 0 and 1 and save ``<path>/m1.jpg``.

        Raises
        ------
        OSError
            If OpenCV reports that the output file could not be written.
        """
        self.get_image()
        self.detecting_key_point()

        self.compute_homography_matrix(self.feature_matching(self.kps_desc[0], self.kps_desc[1]))
        image_1 = self.stitching_image([self.image[0], self.image[1]])

        output_path = f"{path}/m1.jpg"
        # imwrite signals failure through its return value, not an exception.
        if not cv2.imwrite(output_path, image_1):
            raise OSError(f"cv2.imwrite failed to write {output_path!r}")
        
        

In [5]:
#np.set_printoptions(threshold=np.inf)
# Demo on the two images in ./test: show keypoints, then show the matches.
image_1 = image_stitching("./test", 2)
image_1.get_image()
image_1.detecting_key_point()
image_1.draw_key_points()

# Match image 0 (left) against image 1 (right).
kps_desc_left, kps_desc_right = image_1.kps_desc[0], image_1.kps_desc[1]
good_match = image_1.feature_matching(kps_desc_left, kps_desc_right)

image_pair = [image_1.image[0], image_1.image[1]]
image_1.drawmatch(image_pair, good_match)
# image_1.stiching_management()
# print("step 1 done")

# image_2 = image_stitching("./test1",2)
# image_2.stiching_management_1("./test2")
# print("step 2 done")

# image_1 = image_stitching("./test2",2)
# image_1.stiching_management_1("./test3")
# print("step 3 done")

# image_2 = image_stitching("./test3",2)
# image_2.stiching_management_1("./test4")
# print("step 4 done")

# image_2 = image_stitching("./test4",2)
# image_2.stiching_management_1("./test5")
# print("step 5 done")

# image_2 = image_stitching("./test5",2)
# image_2.stiching_management_1("./test6")
# print("step 6 done")

# image_2 = image_stitching("./test6",2)
# image_2.stiching_management_1("./test7")
# print("step 7 done")

# image_2 = image_stitching("./test7",2)
# image_2.stiching_management_1("./test8")
# print("step 8 done")

# image_2 = image_stitching("./test8",2)
# image_2.stiching_management_1("./test9")
# print("step 9 done")


./test/m1.jpg
./test/m2.jpg


In [3]:
# Scratch cell: print the integers 0 through 59, one per line.
for i in range(60):
    print(i)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
