In [None]:
from matplotlib import pyplot as plt
import cv2
import numpy as np
# Print the installed OpenCV version for reference when debugging detector
# availability (the code below relies on cv2.xfeatures2d / cv2.ORB_create).
print('opencv版本: ', cv2.__version__)


def bgr_rgb(img):
    """Swap an OpenCV BGR(A) image to RGB(A) channel order (e.g. for matplotlib).

    Parameters
    ----------
    img : numpy.ndarray
        ``H x W x 3`` (BGR) or ``H x W x 4`` (BGRA) image, or a single-channel
        image (``H x W`` or ``H x W x 1``).

    Returns
    -------
    numpy.ndarray
        A new C-contiguous array with the first and third channels swapped; an
        alpha channel, if present, stays last. Single-channel images have no
        channel order and are returned as an unchanged copy.

    Raises
    ------
    ValueError
        If the image has a channel count other than 1, 3 or 4.
    """
    if img.ndim == 2 or img.shape[-1] == 1:
        # Nothing to reorder (the previous split/merge version raised here).
        return img.copy()
    if img.shape[-1] == 3:
        order = [2, 1, 0]
    elif img.shape[-1] == 4:
        order = [2, 1, 0, 3]
    else:
        raise ValueError("bgr_rgb expects 1, 3 or 4 channels, got {}".format(img.shape[-1]))
    # Fancy indexing copies; ascontiguousarray guarantees the layout cv2.merge produced.
    return np.ascontiguousarray(img[..., order])

def preprocess(img11, img21, filter=None, meanfilter_time1=0, meanfilter_time2=0):
    """Apply ``detect``-style preprocessing to an image pair and display both results.

    Intended for eyeballing preprocessing parameters before feature matching.

    Parameters
    ----------
    img11, img21 : numpy.ndarray
        Input images in OpenCV BGR order; they are copied, never modified.
    filter : {None, "gray", "oldyellow"}
        ``"gray"`` converts both images to grayscale; ``"oldyellow"`` applies
        ``retro_style`` to the second image only. Any other value is ignored.
    meanfilter_time1, meanfilter_time2 : int
        Number of 3x3 mean-filter (``cv2.blur``) passes applied to each image.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        The preprocessed images, still in BGR (or single-channel) layout.
        They are shown in matplotlib figures 1 and 2 as a side effect.
    """
    img1 = img11.copy()
    img2 = img21.copy()

    if filter == "gray":
        img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
        img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
    elif filter == "oldyellow":
        img2 = retro_style(img2)

    # range() of a non-positive count is empty, so no explicit "> 0" guard is needed.
    for _ in range(meanfilter_time1):
        img1 = cv2.blur(img1, (3, 3))
    for _ in range(meanfilter_time2):
        img2 = cv2.blur(img2, (3, 3))

    for fig_num, shown in ((1, img1), (2, img2)):
        plt.figure(fig_num, dpi=200)
        plt.axis("off")
        if shown.ndim == 2:
            # Single channel: use a gray colormap instead of matplotlib's default (viridis).
            plt.imshow(shown, cmap="gray")
        else:
            # matplotlib expects RGB while OpenCV images are BGR; swap for display only.
            plt.imshow(shown[..., [2, 1, 0]])

    return img1, img2

def _create_feature_detector(sf):
    """Return ``(detector, norm_type)`` for algorithm name ``sf``, or ``(None, None)`` if unknown."""
    if sf == "sift":
        # Newer OpenCV exposes SIFT_create in the main module; older contrib builds only in xfeatures2d.
        create = getattr(cv2, "SIFT_create", None) or cv2.xfeatures2d.SIFT_create
        return create(), cv2.NORM_L2
    if sf == "surf":
        # SURF is only reachable through cv2.xfeatures2d.
        return cv2.xfeatures2d.SURF_create(), cv2.NORM_L2
    if sf == "orb":
        # ORB descriptors are binary and must be compared with Hamming distance.
        return cv2.ORB_create(), cv2.NORM_HAMMING
    return None, None


def _apply_filters(img11, img21, filter, meanfilter_time1, meanfilter_time2):
    """Return filtered copies of the two images (see ``detect`` for parameter meanings)."""
    img1 = img11.copy()
    img2 = img21.copy()
    if filter == "gray":
        img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
        img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
    elif filter == "oldyellow":
        img2 = retro_style(img2)
    for _ in range(meanfilter_time1):
        img1 = cv2.blur(img1, (3, 3))
    for _ in range(meanfilter_time2):
        img2 = cv2.blur(img2, (3, 3))
    return img1, img2


def detect(img11, img21, sf="sift", filter=None, meanfilter_time1=0, meanfilter_time2=0, limit=0.6):
    """Match features between two images and estimate the homography mapping image 1 onto image 2.

    Parameters
    ----------
    img11, img21 : numpy.ndarray
        Input images in OpenCV BGR order (not modified).
    sf : {"sift", "surf", "orb"}
        Feature detector/descriptor to use.
    filter : {None, "gray", "oldyellow"}
        Optional preprocessing: ``"gray"`` converts both images to grayscale,
        ``"oldyellow"`` applies ``retro_style`` to the second image only.
    meanfilter_time1, meanfilter_time2 : int
        Number of 3x3 mean-filter passes applied to each image before detection.
    limit : float
        Lowe ratio-test threshold: the best match is kept only if its distance
        is below ``limit`` times the second-best distance.

    Returns
    -------
    tuple
        ``(matches, H, status, matches_img)``. ``matches`` is a list of
        ``(trainIdx, queryIdx)`` pairs that passed the ratio test. ``H`` is the
        3x3 homography from ``cv2.findHomography``; it is ``None`` when fewer
        than 4 matches survive or the estimation fails. ``status`` is the RANSAC
        inlier mask (or ``None``). ``matches_img`` is an RGB visualisation of the
        kept matches. For an unknown ``sf`` all four entries are ``None``.
    """
    detector, norm_type = _create_feature_detector(sf)
    if detector is None:
        print("无效算法")
        # Always return a 4-tuple so callers can unpack the result unconditionally.
        return None, None, None, None

    img1, img2 = _apply_filters(img11, img21, filter, meanfilter_time1, meanfilter_time2)

    # Find keypoints and descriptors.
    kp1, des1 = detector.detectAndCompute(img1, None)
    kp2, des2 = detector.detectAndCompute(img2, None)

    matches = []
    good = []
    # detectAndCompute yields None descriptors when no keypoint is found; knnMatch would raise on them.
    if des1 is not None and des2 is not None:
        rawmatches = cv2.BFMatcher(norm_type).knnMatch(des1, des2, k=2)
        for m in rawmatches:
            # Lowe's ratio test: for a keypoint in image 1, take its two nearest neighbours in
            # image 2 and accept the nearest one only if nearest/second-nearest < limit.
            # The original author found limit values of 0.5-0.6 to work best.
            if len(m) == 2 and m[0].distance < m[1].distance * limit:
                matches.append((m[0].trainIdx, m[0].queryIdx))
                good.append([m[0]])

    # Draw the kept matches on the original (unfiltered) images.
    img3 = cv2.drawMatchesKnn(img11, kp1, img21, kp2, good, None, flags=2)

    # A homography needs at least 4 point correspondences.
    if len(matches) < 4:
        print("没有足够的匹配点")
        return (matches, None, None, bgr_rgb(img3))

    ptsA = np.float32([kp1[i].pt for (_, i) in matches])
    ptsB = np.float32([kp2[i].pt for (i, _) in matches])

    # RANSAC with a 4-pixel reprojection threshold; H can still be None if estimation fails.
    (H, status) = cv2.findHomography(ptsA, ptsB, cv2.RANSAC, 4)
    return (matches, H, status, bgr_rgb(img3))


def opt(image_a, image_b, sf="sift", filter=None, meanfilter_time1=1, meanfilter_time2=1, limit=0.6):
    """Project ``image_a`` onto ``image_b`` using the homography estimated by ``detect``.

    Parameters are forwarded to ``detect`` (note the blur defaults are 1 here, 0 there).

    Returns
    -------
    (matches_img, img)
        ``matches_img`` is the RGB match visualisation from ``detect`` (``None``
        if ``detect`` failed outright). ``img`` is ``image_b`` converted to RGB
        with the perspective-warped ``image_a`` pasted on top, or ``None`` when
        no homography could be estimated.
    """
    detected = detect(image_a, image_b, sf, filter, meanfilter_time1, meanfilter_time2, limit)
    # Be tolerant of a short "failure" tuple instead of crashing while unpacking it.
    if detected is None or len(detected) != 4:
        print("失败")
        return None, None
    _, H, _, matchesimg = detected

    if H is None or not np.any(H):
        print("失败")
        return matchesimg, None

    height, width = image_b.shape[:2]
    # Warp image_a into image_b's frame; pixels outside the projection come out black (zero).
    result = bgr_rgb(cv2.warpPerspective(image_a, H, (width, height)))
    img = bgr_rgb(image_b)

    # Replace image_b's pixels wherever the warped image is not pure black. "Any channel > 0"
    # is the vectorised, overflow-free form of the old per-pixel `sum(pixel) > 0`, which can
    # wrap around for uint8 channel sums >= 256 under NumPy 2's promotion rules. As before,
    # genuinely black pixels of image_a are treated as transparent.
    mask = (result > 0) if result.ndim == 2 else (result > 0).any(axis=-1)
    img[mask] = result[mask]
    return matchesimg, img

def orb_detect(image_a, image_b, max_matches=10):
    """Match ORB features between two images and fit a homography to the best matches.

    Parameters
    ----------
    image_a, image_b : numpy.ndarray
        Input images (BGR order is assumed for the returned visualisation).
    max_matches : int, optional
        How many lowest-distance matches are drawn and passed to
        ``cv2.findHomography`` (default 10, as before).

    Returns
    -------
    tuple
        ``(matches, M, mask, matches_img)``. ``matches`` holds all
        cross-checked matches from ``BFMatcher.match``, in matcher order.
        ``M`` is the 3x3 homography; it is ``None`` when fewer than 4 matches
        are available or estimation fails. ``mask`` is the RANSAC inlier mask
        (or ``None``). ``matches_img`` is an RGB visualisation of the best matches.

    NOTE(review): ``M`` is estimated from image_b points to image_a points,
    i.e. the inverse direction of the ``H`` returned by ``detect``. Confirm
    callers expect this before using it interchangeably with ``detect``.
    """
    orb = cv2.ORB_create()
    kp1, des1 = orb.detectAndCompute(image_a, None)
    kp2, des2 = orb.detectAndCompute(image_b, None)

    if des1 is None or des2 is None:
        # No keypoints in at least one image; bf.match would raise on None descriptors.
        matches = []
    else:
        # Hamming distance for ORB's binary descriptors; crossCheck keeps only mutual best matches.
        bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
        matches = bf.match(des1, des2)

    # Lowest-distance matches first. The same subset is drawn AND used for the homography
    # (previously the unsorted list was sliced, so arbitrary matches fed findHomography).
    best = sorted(matches, key=lambda x: x.distance)[:max_matches]
    img3 = cv2.drawMatches(image_a, kp1, image_b, kp2, best, None, flags=2)

    # findHomography needs at least 4 correspondences.
    if len(best) < 4:
        return (matches, None, None, bgr_rgb(img3))

    src_pts = np.float32([kp1[m.queryIdx].pt for m in best]).reshape(-1, 1, 2)
    dst_pts = np.float32([kp2[m.trainIdx].pt for m in best]).reshape(-1, 1, 2)

    M, mask = cv2.findHomography(dst_pts, src_pts, cv2.RANSAC, 4)
    return (matches, M, mask, bgr_rgb(img3))
# 泛黄老照片效果
def retro_style(img):
    """Apply a yellowed "old photo" colour transform to a BGR image.

    Each output channel is a fixed linear combination of the input R, G and B
    values, truncated to an integer and clipped to ``[0, 255]``. The result is
    identical to the original per-pixel loop but computed in vectorised form.

    Parameters
    ----------
    img : numpy.ndarray
        ``H x W x C`` image with ``C >= 3`` in BGR(A) order. Channels beyond the
        first three (e.g. alpha) are copied through unchanged.

    Returns
    -------
    numpy.ndarray
        A new image with the same shape and dtype as ``img``; ``img`` itself is
        not modified.

    Raises
    ------
    ValueError
        If ``img`` is not a 3-D array with at least 3 channels.
    """
    if img.ndim != 3 or img.shape[2] < 3:
        raise ValueError("retro_style expects an H x W x C image with at least 3 channels")

    # float64 arithmetic in the same operand order as the old scalar code -> bit-identical sums.
    src = img.astype(np.float64)
    b, g, r = src[..., 0], src[..., 1], src[..., 2]
    new_b = 0.273 * r + 0.535 * g + 0.131 * b
    new_g = 0.347 * r + 0.683 * g + 0.167 * b
    new_r = 0.395 * r + 0.763 * g + 0.188 * b

    out = img.copy()
    # np.trunc mirrors int(); clipping keeps values in the 8-bit range so the cast cannot overflow.
    for channel, values in enumerate((new_b, new_g, new_r)):
        out[..., channel] = np.clip(np.trunc(values), 0, 255)
    return out



In [None]:
# Quick parameter sweep: try a range of Lowe ratio-test thresholds on one image pair
# and show the stitched result for each.
if __name__ == '__main__':
    N_TRIALS = 20        # number of thresholds to try
    RATIO_START = 0.8    # first threshold; each trial adds 0.01

    # Load the image pair once rather than on every iteration.
    # (A per-folder variant read picture{n}/1.png, picture{n}/2.png and picture{n}/parm.npy.)
    image_a = cv2.imread('1.png')
    image_b = cv2.imread('2.png')
    if image_a is None or image_b is None:
        # cv2.imread returns None instead of raising when a file is missing or unreadable.
        raise FileNotFoundError("could not read '1.png' and/or '2.png'")

    for i in range(N_TRIALS):
        ratio = RATIO_START + i / 100
        _, img = opt(image_a, image_b, "sift", "gray", 1, 1, ratio)

        # opt returns None for the result image when no homography could be estimated.
        if img is not None:
            fig = plt.figure(i + 1)
            plt.imshow(img)
            plt.axis("off")
            # Format explicitly: 0.8 + i / 100 is not exactly representable (e.g. 0.8200000000000001).
            plt.title("ratiolimit={:.2f}".format(ratio))
            plt.show()
            # Release the figure so 20 iterations do not accumulate open figures.
            plt.close(fig)
    print("done")


In [12]:
# Save the chosen parameters (algorithm, filter, blur passes 1/2, ratio limit), then read them back.
# np.save stores this mixed tuple as a plain *string* array, so no pickling is needed to load it,
# and the numeric fields must be converted back to int/float afterwards.
np.save('parm.npy', ("sift", "gray", 1, 1, 0.6))
sf, filter, meanfiler_time1, meanfiler_time2, ratiolimit = np.load('parm.npy')
# NOTE(review): `filter` shadows the Python builtin; name kept for compatibility with other cells.
sf, filter = str(sf), str(filter)
meanfiler_time1, meanfiler_time2 = int(meanfiler_time1), int(meanfiler_time2)
ratiolimit = float(ratiolimit)
print(sf, filter, meanfiler_time1, meanfiler_time2, ratiolimit)

sift gray 1 1 0.6
