# SIFT + CNN features for image matching
SIFT appears to narrow the search down to the correct candidate set very reliably, especially with FLANN-based k-NN matching. The plan: use SIFT to obtain this smaller candidate set, then use CNN features to pick the best match within it.

In [9]:
import os
import cv2
import pickle
import numpy as np
from glob import glob
from ntpath import basename
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt
from IPython.display import clear_output

from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.applications import InceptionV3
from tensorflow.keras.applications.inception_v3 import preprocess_input

In [7]:
# Feature extractor for the CNN re-ranking stage: ImageNet-pretrained InceptionV3
# with its last layer removed, so calling it yields the penultimate-layer
# activations instead of class predictions.
inception_model = InceptionV3(weights='imagenet')
penultimate = inception_model.layers[-2].output
inception_model = Model(inputs=inception_model.inputs, outputs=penultimate)
del penultimate

In [3]:
# Load the CNN features and SIFT descriptors that were computed and pickled earlier.
# NOTE: pickle.load can execute arbitrary code -- only load files you created yourself.
SAVED_DATA_DIR = os.path.join('..', '..', 'saved_data', '22 Jun')


def load_pickle(filename, data_dir=SAVED_DATA_DIR):
    """Unpickle and return the object stored in ``data_dir/filename``."""
    with open(os.path.join(data_dir, filename), 'rb') as f:
        return pickle.load(f)


feature_stack = load_pickle('feature_stack.pkl')
file_feature_map = load_pickle('file_feature_map.pkl')
file_sift_map = load_pickle('file_sift_map.pkl')

In [13]:
# Two-stage image matcher: SIFT/FLANN match counts shortlist gallery images,
# then CNN feature distances pick the best match within that shortlist.

class ObjectMatcher(object):
    """Match a query image against a gallery of pre-computed descriptors.

    Stage 1 (``find_candidates``) scores every gallery image by how many SIFT
    k-NN matches pass Lowe's ratio test and keeps the highest-scoring ones.
    Stage 2 (``find_best_match``) ranks that shortlist by the Euclidean
    distance between CNN feature vectors.

    Parameters
    ----------
    file_feature_map : dict
        Gallery image path -> CNN feature array.
    file_sift_map : dict
        Gallery image path -> SIFT descriptor array (may be None/empty for
        images without keypoints; those get a score of 0).
    model : keras Model, optional
        CNN feature extractor used by ``evaluate_model`` and ``test_model``.
        When omitted, the notebook-level ``inception_model`` is used.
    """

    def __init__(self, file_feature_map, file_sift_map, model=None):
        self.file_feature_map = file_feature_map
        self.file_sift_map = file_sift_map
        self.paths = list(file_sift_map.keys())
        self.model = model

        # FLANN with randomized kd-trees (algorithm 0 = KDTREE) for float SIFT descriptors.
        FLANN_INDEX_KDTREE = 0
        index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
        search_params = dict(checks=50)

        # SIFT moved from the contrib ``xfeatures2d`` module into the main cv2
        # namespace in OpenCV 4.4; prefer the main one, fall back for older builds.
        sift_factory = getattr(cv2, 'SIFT_create', None)
        if sift_factory is None:
            sift_factory = cv2.xfeatures2d.SIFT_create
        self.sift_ = sift_factory()
        self.flann_ = cv2.FlannBasedMatcher(index_params, search_params)

    def _feature_model(self):
        """Return the CNN used for feature extraction (global ``inception_model`` by default)."""
        return self.model if self.model is not None else inception_model

    @staticmethod
    def _count_good_matches(matches, ratio):
        """Count k-NN match pairs passing the ratio test; entries with < 2 neighbours are skipped."""
        good = 0
        for pair in matches:
            if len(pair) < 2:
                continue
            m, n = pair[0], pair[1]
            if m.distance < ratio * n.distance:
                good += 1
        return good

    @staticmethod
    def _top_k_by_score(paths, scores, top_k):
        """Return the ``top_k`` paths with the HIGHEST scores; ties keep their original order."""
        return sorted(paths, key=scores.__getitem__, reverse=True)[:top_k]

    def find_candidates(self, img_path, top_k=5, ratio=0.7):
        """Shortlist gallery images using SIFT + FLANN k-NN matching.

        Parameters
        ----------
        img_path : str
            Path of the query image (read as grayscale).
        top_k : int
            Number of candidates to return.
        ratio : float
            Ratio-test threshold: a match is "good" when its distance is below
            ``ratio`` times the distance of the second-nearest neighbour.

        Returns
        -------
        list of str
            Up to ``top_k`` gallery paths, ordered from most to fewest good matches.

        Raises
        ------
        ValueError
            If the query image cannot be read or yields no SIFT descriptors.

        Note: rebuilds the FLANN index held in ``self.flann_`` over the query's descriptors.
        """
        img = cv2.imread(img_path, 0)
        if img is None:
            raise ValueError("Could not read image: {}".format(img_path))
        _, img_desc = self.sift_.detectAndCompute(img, None)
        if img_desc is None or len(img_desc) == 0:
            raise ValueError("No SIFT descriptors found in: {}".format(img_path))

        # The query descriptors are the same for every gallery image, so build the
        # FLANN index over them once instead of once per gallery image.
        self.flann_.clear()
        self.flann_.add([img_desc])
        self.flann_.train()

        file_matches_dict = {}
        # leave=False removes the finished bar without clear_output(), which would
        # also erase figures the caller already displayed (e.g. in test_model).
        for path in tqdm(self.paths, leave=False):
            desc = self.file_sift_map[path]
            if desc is None or len(desc) == 0:
                file_matches_dict[path] = 0
                continue
            matches = self.flann_.knnMatch(desc, k=2)
            file_matches_dict[path] = self._count_good_matches(matches, ratio)

        return self._top_k_by_score(self.paths, file_matches_dict, top_k)

    def get_features(self, model, img_path, target_size=(299, 299)):
        """Return ``model``'s output for the image at ``img_path`` as a numpy array.

        The image is resized to ``target_size`` (InceptionV3's default 299x299),
        wrapped in a batch of one and passed through InceptionV3's ``preprocess_input``.
        """
        img = load_img(img_path, target_size=target_size)
        img = img_to_array(img)
        img = np.expand_dims(img, axis=0)
        x = preprocess_input(img)
        features = model(x)
        return features.numpy()

    def find_best_match(self, img_features, candidate_paths):
        """Rank ``candidate_paths`` by Euclidean distance to ``img_features``.

        Returns
        -------
        (str, list of str)
            The closest path and the (up to) three closest paths, nearest first.

        Raises
        ------
        ValueError
            If ``candidate_paths`` is empty.
        """
        if len(candidate_paths) == 0:
            raise ValueError("candidate_paths is empty; nothing to rank.")

        def distance(path):
            return np.linalg.norm(img_features - self.file_feature_map[path])

        ranked_paths = sorted(candidate_paths, key=distance)
        return ranked_paths[0], ranked_paths[:3]

    def evaluate_model(self, root_dir):
        """Measure top-1 and top-3 matching accuracy on images under ``root_dir``.

        Every ``*.jpg`` in each sub-folder of ``root_dir`` (except ``test_images``)
        is used as a query. A query counts as correct when a returned match has
        the same file name.

        Returns
        -------
        (float, float)
            Top-1 and top-3 accuracy in percent (also printed).

        Raises
        ------
        ValueError
            If no ``*.jpg`` files are found.
        """
        model = self._feature_model()
        correct, correct3, num_files = 0, 0, 0
        # Only real sub-directories; 'test_images' is excluded whether or not it exists.
        folders = sorted(
            name for name in os.listdir(root_dir)
            if name != 'test_images' and os.path.isdir(os.path.join(root_dir, name))
        )

        for fol in folders:
            for path in tqdm(sorted(glob(os.path.join(root_dir, fol, '*.jpg')))):
                num_files += 1
                cnd_paths = self.find_candidates(path)
                features = self.get_features(model, path)
                best, best3 = self.find_best_match(features, cnd_paths)
                if basename(best) == basename(path):
                    correct += 1
                if basename(path) in [basename(p) for p in best3]:
                    correct3 += 1

        clear_output()
        if num_files == 0:
            raise ValueError("No .jpg images found in sub-folders of {}".format(root_dir))

        accuracy = 100. * correct / num_files
        accuracy3 = 100. * correct3 / num_files
        print("Accuracy: {:.2f}".format(accuracy))
        print("Accuracy in top 3: {:.2f}".format(accuracy3))
        return accuracy, accuracy3

    def test_model(self, img_paths):
        """Display each query image followed by its (up to) three best matches."""
        model = self._feature_model()
        for img_path in tqdm(img_paths):
            cnd_paths = self.find_candidates(img_path)
            features = self.get_features(model, img_path)
            _, best3 = self.find_best_match(features, cnd_paths)

            # Query image
            fig, ax = plt.subplots()
            ax.imshow(cv2.imread(img_path, 0), cmap='gray')
            ax.set_title('Query: {}'.format(basename(img_path)))
            ax.axis('off')
            plt.show()

            # Best matches, one panel per returned candidate (may be fewer than 3)
            fig, axes = plt.subplots(1, len(best3), squeeze=False)
            for rank, (ax, path) in enumerate(zip(axes[0], best3), start=1):
                ax.imshow(cv2.imread(path, 0), cmap='gray')
                ax.set_title('#{}: {}'.format(rank, basename(path)))
                ax.axis('off')
            fig.tight_layout()
            plt.show()
            print("\n------------------------------------------\n")

In [19]:
# Pick a few distinct gallery images at random (seeded so the run is reproducible)
# and display their best matches.
SEED = 42
N_TEST_IMAGES = 3
rng = np.random.default_rng(SEED)

paths = list(file_feature_map.keys())
# choice(..., replace=False) can select any index, including the last one, and
# never repeats a path (randint's upper bound is exclusive and allows repeats).
idx = rng.choice(len(paths), size=min(N_TEST_IMAGES, len(paths)), replace=False)
chosen_paths = [paths[i] for i in idx]

# Testing
obj_match = ObjectMatcher(file_feature_map, file_sift_map)
obj_match.test_model(chosen_paths)

HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2224), HTML(value='')))

KeyboardInterrupt: 