In [2]:
import numpy as np
import matplotlib.pyplot as plt
import cv2

%matplotlib inline

In [58]:
def gen_data(img01, img02):
    sift = cv2.xfeatures2d.SIFT_create()
    matcher = cv2.BFMatcher_create()
    
    kp01, des01 = sift.detectAndCompute(img01, None)
    kp02, des02 = sift.detectAndCompute(img02, None)
    matches = matcher.match(des01, des02)
    
    return kp01, kp02, matches

def cal_per_error(src_pts, dst_pts, M, n):
    '''
        重新计算重投影误差
    '''
    
    pts = np.ones([n, 3])
    pts[:, :2] = src_pts
    
    per_pts = pts @ M
    for i, _ in enumerate(src_pts):
        src_pts[i] = pts[i, :2] / pts[i, 2]
    
    error = src_pts - dst_pts
    
    return dst_pts.sum()

In [79]:
def RANSAC(kp01, kp02, matches, W=0.5, P=0.995, thre=0.7, n=4):
    '''
        @param pts: (-1, 1, 2)
        @param matches
        @param W: 内点比例
        @param P: 内点的自信力阈值
        @param thre: 内点阈值
        @param n: 选择计算单应矩阵的点对的数量
        
        return good matches by ransac
    '''
    num_matches = len(matches)
    
    # 初始化内点，随机选取 1/3 匹配的点作为内点
    in_pair = [matches[i] for i in np.random.choice(np.arange(num_matches), int(num_matches*0.33))]
    it = 0
    k = 10
    
    while it < k:
        m_idxs = np.random.choice(np.arange(num_matches), int(n)).astype(np.int32)
        
        mm = [matches[i] for i in m_idxs]  # m matches
        cal_ps01 = np.float32([kp01[m.queryIdx].pt for m in mm]).reshape([-1, 2])
        cal_ps02 = np.float32([kp02[m.queryIdx].pt for m in mm]).reshape([-1, 2])
        
        
        M, mask = cv2.findHomography(cal_ps01, cal_ps02)
        # todo: calculate the error for perspective pts again
        error = cal_per_error(cal_ps01, cal_ps02, M, n)
        
        
        W = len(in_pair) / num_matches
        k = np.log(1- P) / np.log(1 - np.power(W, n))
        
        if error < thre:
            in_pair += [m for m in mm if m not in in_pair]
        
        it += 1
        
    return in_pair

In [80]:
img01 = cv2.imread('../data/match_0.jpg', 0)
img02 = cv2.imread('../data/match_1.jpg', 0)

kp01, kp02, matches = gen_data(img01, img02)
ransac_matches = RANSAC(kp01, kp02, matches)

print(len(ransac_matches))

(537, 128) (1262, 128)
177
