# In [None]:
import numpy as np
import pandas as pd
import cv2
import os
import collections
import matplotlib as mpl
from matplotlib import pyplot as plt
import matplotlib.image as mpimg
from tqdm import tqdm
from typing import List, Tuple, Dict, Any
import gc

# Draw descriptor matches found with the mutual (two-way) nearest-neighbour method (NN2W)

# In [None]:
#####################
### SETTINGS
#####################
# Machine-specific base locations of the code, its data and its outputs.
root_dir = '/home/mizzade/Workspace/diplom/code'
image_dir, data_dir = [os.path.join(root_dir, sub) for sub in ('data', 'outputs')]
output_dir = '/home/mizzade/Workspace/diplom/outputs/descriptors/nn2way'

# The image pair of every set and the per-image CSV files (keypoints and
# descriptors) carrying the "_10000" suffix.
iname1, iname2 = '1.png', '2.png'
fname1, fname2 = '1_10000.csv', '2_10000.csv'
file_scheme = '_10000.csv'

# Image collection to process: raw images live below image_dir, the
# detector/descriptor outputs below data_dir.
collection_name = 'eisert'
collection_path_data, collection_path_img = [
    os.path.join(base, collection_name) for base in (data_dir, image_dir)
]

# Evaluation grid: number of keypoints used per image, and the descriptor
# distance below which a mutual nearest neighbour counts as a match.
kpts_thresholds = [1000, 5000, 10000]
desc_distance_thresholds = [0.7]


#####################
### FUNCTIONS
#####################

def normalize_descriptors(desc: np.ndarray) -> np.ndarray:
    """Scale every row of ``desc`` to unit L2 norm.

    Args:
        desc: 2-D array with one descriptor per row.

    Returns:
        A new float array of the same shape whose rows have L2 norm 1.
        All-zero rows are returned as zeros instead of being divided by
        zero (which would turn them into NaN rows).
    """
    norms = np.linalg.norm(desc, axis=1, ord=2).reshape(-1, 1)
    # A zero norm would give 0/0 = NaN. NaN rows propagate through a
    # dot-product distance matrix and break argmin-based matching, so zero
    # rows are divided by 1 and stay all-zero.
    safe_norms = np.where(norms > 0, norms, 1.0)
    return desc / safe_norms

def nn_match_two_way(desc1, desc2, nn_thresh):
    """
    Perform mutual (two-way) nearest-neighbour matching of two descriptor sets.

    A pair (i, j) is kept only if desc2[j] is the nearest neighbour of
    desc1[i], desc1[i] is the nearest neighbour of desc2[j], and their L2
    distance is strictly below ``nn_thresh``.

    Inputs:
      desc1 - N1xM numpy array of N1 M-dimensional descriptors.
      desc2 - N2xM numpy array of N2 M-dimensional descriptors.
              Both sets must be unit-normalized row-wise, because the
              distance is derived from their dot product.
      nn_thresh - Non-negative descriptor distance; only matches with a
                  distance strictly below it are kept.

    Returns:
      matches - Lx3 float numpy array with one row per match,
                [index into desc1, index into desc2, L2 distance],
                where L <= min(N1, N2). Rows are ordered by desc1 index.

    Raises:
      ValueError - if nn_thresh is negative or the descriptor dimensions
                   (number of columns) differ.
    """
    # Validate arguments first so bad input is reported even for empty sets.
    if nn_thresh < 0.0:
        raise ValueError('\'nn_thresh\' should be non-negative')
    if desc1.shape[1] != desc2.shape[1]:
        raise ValueError('descriptor dimensions differ: {} vs {}'
                         .format(desc1.shape[1], desc2.shape[1]))

    # Return zero matches, if one image does not have a keypoint and
    # therefore no descriptors.
    if desc1.shape[0] == 0 or desc2.shape[0] == 0:
        return np.zeros((0, 3))

    # For unit vectors ||a - b||^2 = 2 - 2 a.b. The clip guards against
    # rounding pushing the dot product slightly outside [-1, 1].
    dmat = np.sqrt(2 - 2 * np.clip(np.dot(desc1, desc2.T), -1, 1))

    # Nearest neighbour of every desc1 row in desc2 and its distance.
    nn12 = np.argmin(dmat, axis=1)
    scores = dmat[np.arange(dmat.shape[0]), nn12]

    # Nearest neighbour of every desc2 row in desc1; a match is mutual when
    # going desc1 -> desc2 -> desc1 returns to the starting row.
    nn21 = np.argmin(dmat, axis=0)
    mutual = nn21[nn12] == np.arange(len(nn12))
    keep = np.logical_and(scores < nn_thresh, mutual)

    # Populate the final Lx3 match data structure.
    m_idx1 = np.flatnonzero(keep)
    matches = np.zeros((len(m_idx1), 3))
    matches[:, 0] = m_idx1
    matches[:, 1] = nn12[keep]
    matches[:, 2] = scores[keep]
    return matches

def save_figure(
        path_output: str,
        fig_name: str,
        figure: mpl.figure.Figure,
        dpi: int = 1200,
        tight_layout: bool = False) -> None:
    """Save ``figure`` to ``path_output/fig_name``, creating the directory.

    Args:
        path_output: Target directory; created (with parents) if missing.
            An empty string means the current working directory.
        fig_name: File name; its extension selects the output format.
        figure: Matplotlib figure to write.
        dpi: Resolution in dots per inch.
        tight_layout: If True, crop the saved image to its content
            (``bbox_inches='tight'``) without padding.
    """
    # exist_ok already tolerates an existing directory, so no separate
    # exists() check (and its race) is needed. '' denotes the current
    # directory, which os.makedirs('') would reject.
    if path_output:
        os.makedirs(path_output, exist_ok=True)

    f_out = os.path.join(path_output, fig_name)
    crop_kwargs = {'bbox_inches': 'tight', 'pad_inches': 0} if tight_layout else {}
    figure.savefig(f_out, dpi=dpi, **crop_kwargs)
        
#####################
### MAIN
#####################

# NOTE: To avoid memory errors, handle the number of descriptors, detectors,
# keypoint thresholds etc. with care.

set_names = sorted([x for x in os.listdir(collection_path_img) if os.path.isdir(os.path.join(collection_path_img, x))])


for set_name in tqdm(set_names):

    # 1. Open folder of set
    set_path_2_desc = os.path.join(collection_path_data, set_name, 'descriptors')
    descriptor_names = sorted([x for x in os.listdir(set_path_2_desc)
                               if os.path.isdir(os.path.join(set_path_2_desc, x))])

    # 2. Open the images
    imgs_path = os.path.join(image_dir, collection_name, set_name)
    img1 = mpimg.imread(os.path.join(imgs_path, iname1))
    img2 = mpimg.imread(os.path.join(imgs_path, iname2))

    # (x, y) offset for keypoints of image 2, which is drawn to the right of
    # image 1 once both images are stacked horizontally.
    offsets = np.array([img1.shape[1], 0])

    # Scale factor for tcovdet keypoints: sqrt(h * w / (1024 * 768)).
    # NOTE(review): the guard compares the pixel count with 1024, but the
    # factor is relative to 1024*768 pixels, so images with fewer than
    # 1024*768 pixels get a factor < 1. Confirm which threshold tcovdet uses.
    _h, _w = img1.shape[:2]
    tcovdet_sf = 1.0
    if (_h * _w) > 1024:
        tcovdet_sf = 1. / (1024 * 768 / float(_h * _w))**(0.5)

    # Stack images next to each other.
    img = np.hstack([img1, img2])

    # Remove not needed single images
    del img1, img2, _h, _w
    gc.collect()

    for descriptor_name in descriptor_names:
        print(descriptor_name)
        set_path_2_dets = os.path.join(set_path_2_desc, descriptor_name)

        detector_names = sorted([x for x in os.listdir(set_path_2_dets)
                                 if os.path.isdir(os.path.join(set_path_2_dets, x))])

        for detector_name in detector_names:
            print('\t', detector_name)

            # 3. Open detector keypoints (first two columns).
            kpts_path = os.path.join(collection_path_data, set_name, 'keypoints', detector_name)
            kpts1 = pd.read_csv(os.path.join(kpts_path, fname1), sep=',', comment='#', header=None, usecols=[0, 1]).values
            kpts2 = pd.read_csv(os.path.join(kpts_path, fname2), sep=',', comment='#', header=None, usecols=[0, 1]).values

            # 3.1 If detector is tcovdet, we have to scale keypoints.
            if detector_name == 'tcovdet':
                kpts1 = kpts1 * tcovdet_sf
                kpts2 = kpts2 * tcovdet_sf

            # 4. Open corresponding descriptors and make them unit vectors.
            desc_path = os.path.join(collection_path_data, set_name, 'descriptors', descriptor_name, detector_name)
            desc1 = pd.read_csv(os.path.join(desc_path, fname1), sep=',', comment='#', header=None).values
            desc2 = pd.read_csv(os.path.join(desc_path, fname2), sep=',', comment='#', header=None).values
            desc1 = normalize_descriptors(desc1)
            desc2 = normalize_descriptors(desc2)

            for kpts_thresh in kpts_thresholds:
                # 4.1 Take the first kpts_thresh descriptors/keypoints. The
                # slices are only read; k2 is a new array because of the
                # addition, so the original kpts2 is never modified.
                d1 = desc1[:kpts_thresh]
                d2 = desc2[:kpts_thresh]
                k1 = kpts1[:kpts_thresh]
                # Shift image-2 keypoints into stacked-image coordinates once
                # per subset, not once per distance threshold.
                k2 = kpts2[:kpts_thresh] + offsets

                for desc_dist in desc_distance_thresholds:
                    # 5. Make Nearest Neighbour Two Way Matching for each desc_dist
                    # Column1 contains match indices of d1.
                    # Column2 contains match indices of d2.
                    # Column3 contains distance of descriptors.
                    res = nn_match_two_way(d1, d2, desc_dist)

                    # The builtin int replaces the np.int alias removed in NumPy 1.24.
                    idx1 = res[:, 0].astype(int)
                    idx2 = res[:, 1].astype(int)

                    # Fancy indexing and np.delete already return new arrays.
                    hits1 = k1[idx1]
                    hits2 = k2[idx2]
                    hits = np.vstack([hits1, hits2])
                    misses = np.vstack([np.delete(k1, idx1, 0), np.delete(k2, idx2, 0)])

                    # 6. Draw and save the image.
                    fig, ax = plt.subplots(1, 1)
                    try:
                        # Draw image
                        ax.imshow(img)

                        # Draw misses (no matching keypoints)
                        ax.scatter(misses[:, 0], misses[:, 1], color='blue', marker='o', s=0.5, alpha=0.5)

                        # Draw lines of matches from image1 to image2. With
                        # 2xL x/y arrays, ax.plot draws one line per column,
                        # so all matches are drawn in a single call.
                        if len(res) > 0:
                            ax.plot(np.vstack([hits1[:, 0], hits2[:, 0]]),
                                    np.vstack([hits1[:, 1], hits2[:, 1]]),
                                    color='lawngreen', alpha=1, linewidth=0.5)

                        # Draw hits (matching keypoints)
                        ax.scatter(hits[:, 0], hits[:, 1], color='r', marker='o', s=0.5)

                        ax.set_xticks([])
                        ax.set_yticks([])

                        fig_name = '{}_{}__{}_{}__{}_{}.png' \
                            .format(descriptor_name, detector_name, collection_name, set_name, kpts_thresh, desc_dist)

                        save_figure(output_dir, fig_name, fig, tight_layout=True)
                    finally:
                        # Always release the figure, even if drawing/saving fails.
                        plt.close(fig)

                    # Free per-threshold results.
                    del res, idx1, idx2, hits1, hits2, hits, misses
                    gc.collect()

                # Free the subsets only after every distance threshold used them.
                del d1, d2, k1, k2

            # Free some memory
            del kpts1, kpts2, desc1, desc2
            gc.collect()

    # Free even more memory
    del img
    gc.collect()
                