## Include libraries

In [1]:
import cv2
import numpy as np
import numpy.linalg as la
import re
import matplotlib.pyplot as plt
%matplotlib inline
from scipy import signal
import scipy

## Helper: brute-force k-nearest-neighbour descriptor matching

In [2]:
def find_match(des_1, des_2, k, metric='euclidean'):
    """Return the ``k`` nearest descriptors in ``des_2`` for every descriptor in ``des_1``.

    Brute-force matcher: builds the full pairwise distance matrix with
    ``scipy.spatial.distance.cdist`` and, for each query row, keeps the ``k``
    closest train rows in ascending distance order.

    Parameters
    ----------
    des_1 : 2-D array-like (n_query, d) or None
        Query descriptors, e.g. the ``des`` returned by ``sift.detectAndCompute``.
    des_2 : 2-D array-like (n_train, d) or None
        Train descriptors with the same dimensionality ``d`` as ``des_1``.
    k : int
        Number of nearest train descriptors kept per query. Fewer are returned
        when ``des_2`` has fewer than ``k`` rows.
    metric : str, optional
        Any metric accepted by ``cdist``. Defaults to ``'euclidean'`` so that
        ``DMatch.distance`` is a plain L2 distance and a ratio test such as
        ``m.distance < r * n.distance`` uses ``r`` as written. With
        ``'sqeuclidean'`` the same test would effectively use ``sqrt(r)``.

    Returns
    -------
    list of list of cv2.DMatch
        One inner list per row of ``des_1`` (``[]`` overall if ``des_1`` is None
        or empty). Inner list ``i`` holds up to ``k`` matches with
        ``queryIdx == i``, sorted by ascending ``distance``.
    """
    # Explicit submodule import: a bare `import scipy` does not guarantee that
    # `scipy.spatial` has been loaded.
    from scipy.spatial.distance import cdist

    # detectAndCompute returns None descriptors when no keypoints are found.
    if des_1 is None or len(des_1) == 0:
        return []
    if des_2 is None or len(des_2) == 0:
        return [[] for _ in range(len(des_1))]

    dist_mat = cdist(des_1, des_2, metric)
    # Column indices of the k smallest distances in every row; a stable sort
    # keeps equal distances in train-index order (deterministic tie-breaking).
    nearest = np.argsort(dist_mat, axis=1, kind='stable')[:, :k]

    matches = []
    for query_idx, train_idxs in enumerate(nearest):
        matches.append([
            # imgIdx is always 0: every train descriptor comes from des_2.
            cv2.DMatch(query_idx, int(train_idx), 0, float(dist_mat[query_idx, train_idx]))
            for train_idx in train_idxs
        ])
    return matches

## Inputs

In [3]:
# Folder holding the input frames. Paths are later built as IMG_DIR + file name,
# so the trailing slash is required.
IMG_DIR = 'image_0/'
# The two frames whose SIFT features are matched against each other.
IMG_NAME_1, IMG_NAME_2 = '000000.png', '000001.png'

## Find keypoints & descriptors

In [4]:
img_1 = cv2.imread(IMG_DIR + IMG_NAME_1)
img_2 = cv2.imread(IMG_DIR + IMG_NAME_2)
# cv2.imread does not raise on a missing/unreadable file; it returns None, which
# would only surface later as a confusing error inside detectAndCompute.
for _path, _img in ((IMG_DIR + IMG_NAME_1, img_1), (IMG_DIR + IMG_NAME_2, img_2)):
    if _img is None:
        raise FileNotFoundError('could not read image: ' + _path)

# SIFT is in the main cv2 module since OpenCV 4.4; older builds only expose it
# through the opencv-contrib `xfeatures2d` module.
if hasattr(cv2, 'SIFT_create'):
    sift = cv2.SIFT_create()
else:
    sift = cv2.xfeatures2d.SIFT_create()
kp1, des1 = sift.detectAndCompute(img_1, None)
kp2, des2 = sift.detectAndCompute(img_2, None)

## Filtering matches with Lowe's ratio test

In [5]:
# Ratio-test threshold: keep a match only when the best candidate is clearly
# closer than the second-best one.
RATIO_THRESHOLD = 0.8

matches = find_match(des1, des2, k=2)
good = []      # accepted best matches (DMatch)
pts1 = []      # (x, y) of accepted keypoints in image 1
pts2 = []      # (x, y) of the corresponding keypoints in image 2
kpt1_E = []    # accepted KeyPoint objects in image 1
kpt2_E = []    # corresponding KeyPoint objects in image 2
for candidates in matches:
    if len(candidates) < 2:
        # The ratio test needs a second-best candidate; skip queries without one.
        continue
    m, n = candidates[0], candidates[1]
    if m.distance < RATIO_THRESHOLD * n.distance:
        good.append(m)
        pts1.append(kp1[m.queryIdx].pt)
        pts2.append(kp2[m.trainIdx].pt)
        kpt1_E.append(kp1[m.queryIdx])
        # kp2 must be indexed with trainIdx (image-2 index), not queryIdx.
        kpt2_E.append(kp2[m.trainIdx])

# reshape keeps an (N, 2) shape even when no match survives (N == 0).
pts1_E = np.asarray(pts1, dtype=float).reshape(-1, 2)
pts2_E = np.asarray(pts2, dtype=float).reshape(-1, 2)

## Plot and save the matched keypoints

In [6]:
# Draw only the keypoints that survived the ratio test. The RICH flag draws each
# keypoint with its size and orientation rather than as a small dot.
SIFT_1 = cv2.drawKeypoints(img_1, kpt1_E, None, flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
SIFT_2 = cv2.drawKeypoints(img_2, kpt2_E, None, flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)

# cv2.imwrite reports failure through its boolean return value; surface it.
for out_path, vis in (('SIFT_matches_image1.png', SIFT_1), ('SIFT_matches_image2.png', SIFT_2)):
    if not cv2.imwrite(out_path, vis):
        raise OSError('failed to write ' + out_path)

cv2.imshow('Matches from image1', SIFT_1)
cv2.imshow('Matches from image2', SIFT_2)
cv2.waitKey()
# Close the HighGUI windows; otherwise they stay open (and can hang the kernel)
# after the key press.
cv2.destroyAllWindows()

-1