# Feature Matching

This notebook evaluates feature matching performance for a number of test scenarios.

TODO: Some images clearly have bad homographies but are blended together anyway.

TODO: Refactor the code in here to use the code that the mosaics use.

# Setup

## Imports

In [None]:
# import copy
import glob
import itertools
import os
import time

In [None]:
import cv2
import numpy as np
import pandas as pd
# from sklearn.model_selection import train_test_split
# from sklearn.pipeline import Pipeline
# from sklearn.utils import check_random_state
import tqdm.notebook
import yaml

In [None]:
import matplotlib
import matplotlib.pyplot as plt
# import matplotlib.patches as patches
import seaborn as sns
sns.set_style('white')

In [None]:
from night_horizons import utils, raster, features

## Settings

In [None]:
# Load the shared project configuration; notebook-specific values are merged on top of it.
# NOTE(review): yaml.full_load also accepts some Python-specific tags; if config.yml only
# holds plain YAML data, yaml.safe_load would be the safer choice.
with open('./config.yml', 'r', encoding='UTF-8') as config_file:
    settings = yaml.full_load(config_file)

In [None]:
local_settings = {
    # Filetree settings
    'test_images_dir': '../test_data/feature_matching/',
    'src_format': 'src_{:03d}.tiff',
    'dst_format': 'dst_{:03d}.tiff',

    # Feature matching options
    'feature_detectors': [
        ('ORB', {}),
        ('SIFT', {}),
        # Still marked as patented in the opencv version I'm using.
        # (cv2.xfeatures2d.SURF_create, {}),
        ('AKAZE', {}),
        # ('BRISK', {}),
        # Does not seem to be fully implemented in OpenCV
        # ('FastFeatureDetector', {}),
        # Does not seem to be fully implemented in OpenCV
        # ('MSER', {}),
    ],
    'feature_matchers': [
        # TODO: Explore other feature matchers.
        ('BFMatcher', {}),
        # ('FlannBasedMatcher', {}),
        # ('BFMatcher', {'k': [10,]}),
        # TODO: Try Grid-based Motion Statistics. Very fast, but more complicated.
    ],
    'transform_param_grid': {
        'homography_method': [
            cv2.RANSAC,
            # cv2.RHO,
            # These don't really show promise
            # cv2.LMEDS,
            # 0,
        ],
        # 'ransacReprojThreshold': np.arange(1, 10),
        # 'maxIters': [100, 1000, 2000, 10000],
        'n_matches_used': [10, 100, 500, None],
        'dark_frame_brightness': [0.03, ],
        'dark_frame_percentile': [0.99, ],
    },

    # Analysis parameters
    'det_min': 0.6,
    'det_max': 2.,
    'n_images': 10000,
    'show_images': True,
}
settings.update(local_settings)

## Parse Settings

In [None]:
# Initialize the feature detectors: expand each (detector name, param grid) entry into
# one (name, options) pair per combination of the grid's candidate values.
# Entries whose grid is empty are kept exactly as given.
feature_detectors = []
for subsettings in settings['feature_detectors']:
    detector_name, detector_grid = subsettings[0], subsettings[1]
    if len(detector_grid) == 0:
        feature_detectors.append(subsettings)
        continue
    grid_keys = tuple(detector_grid)
    feature_detectors += [
        (detector_name, dict(zip(grid_keys, combo)))
        for combo in itertools.product(*detector_grid.values())
    ]

In [None]:
# Initialize the feature matchers: each (matcher name, param grid) entry becomes one
# (name, options) pair per combination of candidate values; empty grids pass through.
feature_matchers = []
for subsettings in settings['feature_matchers']:
    matcher_name, matcher_grid = subsettings[0], subsettings[1]
    if len(matcher_grid) == 0:
        feature_matchers.append(subsettings)
        continue
    grid_keys = tuple(matcher_grid.keys())
    feature_matchers.extend(
        (matcher_name, dict(zip(grid_keys, combo)))
        for combo in itertools.product(*matcher_grid.values())
    )

In [None]:
# Get transform kwargs: one kwargs dict per combination of the candidate values in
# transform_param_grid.
# Iterating keys()/values() directly (instead of unpacking zip(*items())) keeps an
# empty grid valid: itertools.product() with no iterables yields one empty tuple, so an
# empty grid yields [{}], i.e. a single run with no extra transform kwargs.
param_grid = settings['transform_param_grid']
transform_keys = list(param_grid.keys())
transform_kwargs = [
    dict(zip(transform_keys, combo))
    for combo in itertools.product(*param_grid.values())
]

## Code

In [None]:
class JoinSearcher:
    """Grid-search ``features.ImageJoiner`` settings for one src/dst image pair.

    The source image is the one warped onto the destination image. After
    :meth:`grid_search`, the per-combination results are stored on ``self.df``,
    ``self.best_row`` and ``self.gs_info``. After :meth:`warp_and_blend`,
    :meth:`show_after` displays the warped and blended images.
    """

    def __init__(self, src_fp, dst_fp):
        """Open both images and keep their ``img_int`` arrays for matching.

        Args:
            src_fp: Path to the source image (the one that gets warped).
            dst_fp: Path to the destination image.
        """
        self.src_fp = src_fp
        self.src_image = raster.Image.open(src_fp)
        self.src_img = self.src_image.img_int
        self.dst_fp = dst_fp
        self.dst_image = raster.Image.open(dst_fp)
        self.dst_img = self.dst_image.img_int

    @staticmethod
    def _show_side_by_side(panels, img):
        """Draw ``(mosaic_key, raster_image, title)`` panels left to right on one figure."""
        fig = plt.figure(figsize=(20, 10))
        ax_dict = fig.subplot_mosaic([[key for key, _, _ in panels]])
        for key, image, title in panels:
            ax = ax_dict[key]
            image.show(ax=ax, img=img)
            ax.set_title(title)
        plt.tight_layout()

    def show_before(self, img='semitransparent_img'):
        """Show the destination (left) and source (right) images before joining.

        Args:
            img: Passed through as the ``img`` argument of ``raster.Image.show``
                (e.g. ``'img'`` or ``'semitransparent_img'``).
        """
        self._show_side_by_side(
            [('dst_img', self.dst_image, 'dst'), ('src_img', self.src_image, 'src')],
            img,
        )

    def show_after(self, img='semitransparent_img'):
        """Show the warped source (left) and the blended result (right).

        Requires :meth:`warp_and_blend` to have been called first.

        Args:
            img: Passed through as the ``img`` argument of ``raster.Image.show``.
        """
        self._show_side_by_side(
            [
                ('warped_img', self.warped_image, 'warped'),
                ('blended_img', self.blended_image, 'blended'),
            ],
            img,
        )

    def warp_and_blend(self, M):
        """Warp the source image with ``M`` and blend it onto the destination.

        Sets ``warped_img``/``warped_image`` and ``blended_img``/``blended_image``.

        Args:
            M: Transform handed to ``features.ImageJoiner.warp``; presumably the
                homography found by :meth:`grid_search` (TODO confirm).
        """
        self.warped_img = features.ImageJoiner.warp(self.src_img, self.dst_img, M)
        self.warped_image = raster.Image(self.warped_img)

        self.blended_img = features.ImageJoiner.blend(self.warped_img, self.dst_img)
        self.blended_image = raster.Image(self.blended_img)

    def grid_search(
        self,
        feature_detectors,
        feature_matchers,
        transform_kwargs,
        log_keys=None,
        n_images=None,
    ):
        """Run ImageJoiner.join for every detector x matcher x transform combination.

        Args:
            feature_detectors: List of ``(name, options)`` pairs, forwarded as
                ``feature_detector`` / ``feature_detector_options``.
            feature_matchers: List of ``(name, options)`` pairs, forwarded as
                ``feature_matcher`` / ``feature_matcher_options``.
            transform_kwargs: List of dicts, each splatted into ``ImageJoiner``.
            log_keys: Keys forwarded to ``ImageJoiner`` and pre-filled with NaN in
                every row. Defaults to ``['abs_det_M', 'dark_frac']``; should include
                ``'abs_det_M'`` because ``warp_factor`` is computed from it.
            n_images: Number of images used to extrapolate the total runtime from the
                best per-pair duration. Defaults to ``settings['n_images']``.

        Returns:
            ``(df, best_row, gs_info)``: one row per combination (indices ``i_fd``,
            ``j_fm``, ``k_tk`` into the inputs plus everything the join returned),
            the selected row, and a summary dict. All three are also stored on self.
            The fastest successful row is selected; if none succeeded, the row with
            the smallest ``warp_factor`` is.
        """
        if log_keys is None:
            log_keys = ['abs_det_M', 'dark_frac']

        n_total = len(feature_detectors) * len(feature_matchers) * len(transform_kwargs)

        rows = []
        # The bar advances exactly once per join, so it ends at n_total.
        with tqdm.notebook.tqdm(total=n_total) as pbar:
            for i, fd_settings in enumerate(feature_detectors):
                for j, fm_settings in enumerate(feature_matchers):
                    for k, t_kwargs in enumerate(transform_kwargs):
                        image_joiner = features.ImageJoiner(
                            feature_detector=fd_settings[0],
                            feature_detector_options=fd_settings[1],
                            feature_matcher=fm_settings[0],
                            feature_matcher_options=fm_settings[1],
                            log_keys=log_keys,
                            debug_mode=True,
                            **t_kwargs
                        )

                        return_code, results_ijk, log = image_joiner.join(self.src_img, self.dst_img)

                        row = {
                            'return_code': return_code,
                            'i_fd': i,
                            'j_fm': j,
                            'k_tk': k,
                        }
                        # NaN placeholders so the requested log columns exist even
                        # when a join does not report them.
                        row.update({key: np.nan for key in log_keys})
                        row.update(results_ijk)
                        row.update(log)
                        rows.append(row)

                        pbar.update(1)

        df = pd.DataFrame(rows)

        # Measure of how warped the image is: distance of abs_det_M from 1 on a log scale.
        df['warp_factor'] = np.abs(np.log10(np.abs(df['abs_det_M'])))

        # Identify the best set of parameters.
        df['valid_M'] = df['return_code'] == 'success'
        valid_df = df.loc[df['valid_M']]
        if len(valid_df) > 0:
            # Fastest successful combination ('duration' presumably comes from the
            # join's results/log; it is not pre-filled here).
            best_ind = valid_df['duration'].idxmin()
        else:
            # Nothing succeeded: report the least-warped attempt instead. If no
            # attempt has a finite warp_factor, idxmin has no answer, so fall back
            # to the first attempt.
            warp_factor = df['warp_factor']
            if warp_factor.notna().any():
                best_ind = warp_factor.idxmin()
            else:
                best_ind = df.index[0]
        best_row = df.loc[best_ind]

        if n_images is None:
            n_images = settings['n_images']
        t_best_ind = pd.Timedelta(n_images * best_row['duration'], unit='second')

        gs_info = {
            'n_valid': len(valid_df),
            'best_ind': best_ind,
            't_best_ind': t_best_ind,
            'best_fd': feature_detectors[best_row['i_fd']],
            'best_fm': feature_matchers[best_row['j_fm']],
            'best_tk': transform_kwargs[best_row['k_tk']],
        }

        if len(valid_df) > 0:
            print(
                f'''Grid search complete.
                    {gs_info['n_valid']} valid results.
                    Best valid time was {best_row['duration']:.2g} seconds, for an estimated total time of {gs_info['t_best_ind']}.
                    The best feature detector was {gs_info['best_fd']}
                    The best feature matcher was {gs_info['best_fm']}
                    The best transform kwargs were {gs_info['best_tk']}
                '''
            )
        else:
            print(
                f'''No successes found. Closest det_min was {best_row['abs_det_M']:.3g}
                '''
            )

        self.df = df
        self.best_row = best_row
        self.gs_info = gs_info

        return df, best_row, gs_info


# Image Pairs

## A Particular Set

In [None]:
# Inspect a single src/dst pair in detail before running over every pair.
i = 3
images_dir = settings['test_images_dir']
src_fp = os.path.join(images_dir, settings['src_format'].format(i))
dst_fp = os.path.join(images_dir, settings['dst_format'].format(i))
js = JoinSearcher(src_fp, dst_fp)

In [None]:
# Side-by-side look at the raw dst/src pair before any matching.
show_images = settings['show_images']
if show_images:
    js.show_before(img='img')

In [None]:
# Try every detector / matcher / transform combination on the chosen pair.
df, best_row, grid_search_results = js.grid_search(
    feature_detectors=feature_detectors,
    feature_matchers=feature_matchers,
    transform_kwargs=transform_kwargs,
)

In [None]:
# Warp and blend with the selected transform, if the best row recorded one.
# NOTE(review): when no combination succeeded, best_row is the least-warped failure and
# its M may be missing or NaN; confirm warp_and_blend copes with that.
has_transform = 'M' in best_row.index
if has_transform:
    js.warp_and_blend(best_row['M'])
    if settings['show_images']:
        js.show_after()

## All Sets

In [None]:
# Collect every test pair. The loop below pairs src_fps[i] with dst_fps[i] purely by
# sorted position, so verify the two lists are non-empty and line up file-for-file.
src_fps = sorted(glob.glob(os.path.join(settings['test_images_dir'], 'src_*.tiff')))
dst_fps = sorted(glob.glob(os.path.join(settings['test_images_dir'], 'dst_*.tiff')))
assert src_fps, f"No src_*.tiff images found in {settings['test_images_dir']}"
src_ids = [os.path.basename(fp)[len('src_'):] for fp in src_fps]
dst_ids = [os.path.basename(fp)[len('dst_'):] for fp in dst_fps]
assert src_ids == dst_ids, (
    f'src/dst images do not pair up: {len(src_fps)} src vs {len(dst_fps)} dst files, '
    f'unmatched ids: {sorted(set(src_ids) ^ set(dst_ids))}'
)

In [None]:
# Grid-search every image pair; each per-pair results frame is tagged with its pair index.
results = []
n_pairs = len(src_fps)
for i in range(n_pairs):
    print(f'i = {i} / {n_pairs}')
    src_fp, dst_fp = src_fps[i], dst_fps[i]

    fc = JoinSearcher(src_fp, dst_fp)
    df, best_row, grid_search_results = fc.grid_search(
        feature_detectors, feature_matchers, transform_kwargs
    )

    df['set'] = i
    results.append(df)

## Summarize

In [None]:
# Stack the per-pair frames into one table with a fresh 0..N-1 index.
df = pd.concat(results, axis=0, ignore_index=True)

In [None]:
# Label each (detector, matcher, transform) combination, e.g. "i0_j0_k2".
ijk_parts = df[['i_fd', 'j_fm', 'k_tk']].astype(str)
df['ijk'] = 'i' + ijk_parts['i_fd'] + '_j' + ijk_parts['j_fm'] + '_k' + ijk_parts['k_tk']

In [None]:
# Successful joins only, and successful-or-dark-frame joins (both count as usable below).
valid_df = df[df['valid_M']]
valid_or_dark_df = df[df['return_code'].isin(('success', 'dark_frame'))]

In [None]:
# Number of distinct image pairs evaluated.
n_sets = df['set'].nunique()

### Overview

In [None]:
# Overview of every run (combination x image pair).
# warp_factor = |log10(|abs_det_M|)| (see JoinSearcher.grid_search), so the reference
# lines mark where abs_det_M equals the det_min / det_max analysis settings.
ax = sns.scatterplot(
    data=df,
    x='duration',
    y='warp_factor',
    hue='valid_M',
)
for det_key, linestyle in (('det_min', '-'), ('det_max', '--')):
    det_value = settings[det_key]
    y_line = np.abs(np.log10(det_value))
    ax.axhline(y_line, color='k', linestyle=linestyle, linewidth=1)
    # Label the line at the right edge (x in axes coordinates, y in data coordinates).
    ax.text(
        0.99, y_line, f'{det_key} = {det_value:g}',
        transform=ax.get_yaxis_transform(), ha='right', va='bottom',
    )
ax.set_ylim(0, ax.get_ylim()[1])
ax.set(
    xlabel='duration [s]',
    ylabel='warp factor |log10 |det M||',
    title='Feature-matching grid search: all runs on all image pairs',
);

### Identify Promising Parameters

In [None]:
# Feature detectors that work across all image pairs
n_valid_sets = valid_or_dark_df.groupby('i_fd')['set'].nunique()
promising_fd = [feature_detectors[_] for _ in n_valid_sets.index[n_valid_sets==n_sets]]
promising_fd

In [None]:
# Feature matchers that work across all image pairs
n_valid_sets = valid_or_dark_df.groupby('j_fm')['set'].nunique()
promising_fm = [feature_matchers[_] for _ in n_valid_sets.index[n_valid_sets==n_sets]]
promising_fm

In [None]:
# Transform parameters that work across all image pairs
n_valid_sets = valid_or_dark_df.groupby('k_tk')['set'].nunique()
promising_t_kwargs = [transform_kwargs[_] for _ in n_valid_sets.index[n_valid_sets==n_sets]]
method_map = {
    getattr(cv2, method): method
    for method in ['RANSAC', 'LMEDS', 'RHO']
}
promising_t_kwargs = [
    {
        key:(method_map[value] if key == 'method' else value)
        for key, value in t_kwargs.items()
    } for t_kwargs in promising_t_kwargs
]
promising_t_kwargs

In [None]:
# And now the combinations that are fully good
ijk_groups = valid_or_dark_df.groupby('ijk')
n_valid_sets = ijk_groups['set'].nunique()
is_good = n_valid_sets == n_sets
good_ijks = n_valid_sets.index[is_good]

In [None]:
if is_good.any():
    # One row per fully-good combination (first non-null value per column within the
    # group), with duration replaced by its mean over image pairs; quickest first.
    mean_duration = ijk_groups['duration'].mean()
    good_df = (
        ijk_groups.first()
        .loc[good_ijks]
        .assign(duration=mean_duration.loc[good_ijks])
        .sort_values('duration')
    )
    best_row = good_df.iloc[0]

    # Print the best (quickest while still valid) combination
    print(
        feature_detectors[best_row['i_fd']],
        feature_matchers[best_row['j_fm']],
        transform_kwargs[best_row['k_tk']],
    )
else:
    print('No single set of parameters works for all images.')