In [None]:
import numpy as np
import cv2
import os
import yaml
import pickle
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from tqdm.notebook import tqdm

from superpoint.superpoint import SuperPointFrontend
from superpoint.utils import get_query_img_name, get_refer_img_name

# Read the data

In [None]:
# Load the evaluation configuration, then derive the locations of the two
# pickled failed-case files that the rest of the notebook compares.
with open('config/eval_config.yaml', 'r') as fin:
    config = yaml.safe_load(fin)

DATASET = config["dataset"]["name"]


def _failed_cases_path(output_key):
    """Return the failed-cases file path for the run stored under config[output_key]."""
    run_cfg = config[output_key]
    tail = "/".join((DATASET, run_cfg["method"], run_cfg["anchor_select_policy"], "failed_cases"))
    # results_path is prepended verbatim (no separator is inserted), so the
    # configured value has to carry its own trailing "/".
    return config["results_path"] + tail


ATTEN_PATH = _failed_cases_path("output_1")
FULL_GEO_PATH = _failed_cases_path("output_2")

def read_failed_cases(path):
    """Load and return the object stored in the pickle file at ``path``.

    Args:
        path: Location of a file previously written with ``pickle.dump``.

    Returns:
        The unpickled object, exactly as stored.
    """
    # Unpickling can execute arbitrary code: only open files you produced yourself.
    with open(path, 'rb') as pickle_file:
        return pickle.load(pickle_file)

In [None]:
# read the failed cases of the attention-patch run (stored at ATTEN_PATH)
atten_failed_cases = read_failed_cases(ATTEN_PATH)

# read the failed cases of the second run being compared (stored at FULL_GEO_PATH)
all_failed_cases = read_failed_cases(FULL_GEO_PATH)

print(f"The number of failed cases in atten patch is {len(atten_failed_cases)}")
print(f"The number of failed cases in cross match is {len(all_failed_cases)}")

In [None]:
# Image directories, built by plain string concatenation of the dataset root
# and the query / reference sub-directory taken from the config.
QUERY_PATH = config["dataset"]["Root"] + config["dataset"]["query_dir"]
REF_PATH = config["dataset"]["Root"] + config["dataset"]["refer_dir"]

# get the images of failed cases
def get_failed_images(failed_cases, query_path, ref_path):
    """Build the image file paths belonging to every failed case.

    Each case is read positionally: ``case[0]`` is the query id, ``case[1]``
    the wrongly predicted id and ``case[2][0]`` the id used for the reference
    image. File names are the id zero-padded to 7 digits plus ``.jpg``.

    Args:
        failed_cases: Iterable of cases indexed as described above.
        query_path: Directory holding the query images.
        ref_path: Directory holding the reference images.

    Returns:
        A tuple ``(query_images, wrong_pred, ref_images)`` of three lists of
        paths, aligned with the order of ``failed_cases``.
    """
    def image_file(directory, image_id):
        # the image name is in 7 digits and in the format of .jpg
        return os.path.join(directory, str(image_id).zfill(7) + '.jpg')

    query_images = []
    wrong_pred = []
    ref_images = []
    for case in failed_cases:
        query_images.append(image_file(query_path, case[0]))
        # The wrong prediction is an image retrieved from the reference set
        # (the other cells of this notebook load it from the reference
        # directory), so it must be resolved against ref_path, not query_path.
        wrong_pred.append(image_file(ref_path, case[1]))
        ref_images.append(image_file(ref_path, case[2][0]))
    return query_images, wrong_pred, ref_images

def display_failed_images(failed_cases):
    """Show, for every failed case, the query, wrong prediction and reference side by side.

    Image paths are resolved with ``get_failed_images`` against the notebook's
    ``QUERY_PATH`` and ``REF_PATH``. One 1x3 figure is produced per case.

    Args:
        failed_cases: Iterable of failed cases as accepted by ``get_failed_images``.
    """
    query_images, wrong_pred, ref_images = get_failed_images(failed_cases, QUERY_PATH, REF_PATH)

    labels = ["query image", "wrong prediction", "reference image"]

    # display the images of failed cases
    for image_paths in zip(query_images, wrong_pred, ref_images):
        fig, axes = plt.subplots(1, 3, figsize=(12, 4))
        for ax, label, image_path in zip(axes, labels, image_paths):
            ax.imshow(mpimg.imread(image_path))
            ax.set_title(label)
            # plt.axis('off') would only hide the ticks of the last subplot;
            # switch them off on every panel instead.
            ax.axis('off')
        plt.show()
        # Release the figure so long lists of cases do not pile up in memory.
        plt.close(fig)
        
# Plot the failed cases
# display_failed_images(atten_failed_cases)
# display_failed_images(all_failed_cases)

## Compare the common and differing failed cases between the full-geometry (all) and attention-patch (atten) runs

In [None]:
class DatasetComparer:
    """Compare two collections of failed cases by their query identifier.

    Every item is read positionally: ``item[0]`` is the query and ``item[1]``
    the wrong prediction; any further fields are carried along untouched.
    Queries must be hashable.

    Attributes:
        data1, data2: The two collections being compared.
        common_queries: Queries present in both collections.
        different_queries: Queries present in exactly one collection.
    """

    def __init__(self, data1, data2):
        self.data1 = data1
        self.data2 = data2
        self.common_queries = []
        self.different_queries = []
        # Set once compare_datasets() has run, so that lazily triggered
        # comparisons are not repeated when both result lists are truly empty.
        self._compared = False

    def compare_datasets(self):
        """Fill ``common_queries`` and ``different_queries`` and print their sizes."""
        # dict.fromkeys de-duplicates while keeping first-seen order, which
        # makes the resulting lists deterministic (plain sets are not ordered).
        queries1 = list(dict.fromkeys(item[0] for item in self.data1))
        queries2 = list(dict.fromkeys(item[0] for item in self.data2))
        seen1, seen2 = set(queries1), set(queries2)

        # Find common queries
        self.common_queries = [query for query in queries1 if query in seen2]

        # Find different queries (symmetric difference: data1-only, then data2-only)
        self.different_queries = ([query for query in queries1 if query not in seen2]
                                  + [query for query in queries2 if query not in seen1])
        self._compared = True

        print(f"The number of failed cases in common is {len(self.common_queries)}")
        print(f"The number of failed cases in different is {len(self.different_queries)}")

    def _ensure_compared(self):
        """Run the comparison once if it has not been run yet."""
        if not getattr(self, "_compared", False):
            self.compare_datasets()

    @staticmethod
    def _first_prediction_by_query(data):
        """Map each query to the wrong prediction of its first occurrence in ``data``."""
        first_seen = {}
        for item in data:
            first_seen.setdefault(item[0], item[1])
        return first_seen

    def are_wrong_predictions_same(self):
        """Return the common queries for which both collections made the same wrong prediction.

        Returns:
            Dict mapping query -> ``(prediction_in_data1, prediction_in_data2)``,
            restricted to queries where the two predictions compare equal.
        """
        self._ensure_compared()

        preds1 = self._first_prediction_by_query(self.data1)
        preds2 = self._first_prediction_by_query(self.data2)

        return {query: (preds1[query], preds2[query])
                for query in self.common_queries
                if preds1[query] == preds2[query]}

    def find_wrong_prediction(self, data, query):
        """Return the wrong prediction of the first item in ``data`` matching ``query``, or None."""
        return next((item[1] for item in data if item[0] == query), None)

    def find_data_for_different_queries(self):
        """Return the items of ``data1`` whose query is absent from ``data2``.

        The comparison is run on demand if ``compare_datasets`` has not been
        called yet. Items are returned in their ``data1`` order.
        """
        self._ensure_compared()

        # different_queries also holds data2-only queries; those simply never
        # match an item of data1.
        wanted = set(self.different_queries)
        return [item for item in self.data1 if item[0] in wanted]

In [None]:
# Compare the attention-patch failures (data1) with the other run's failures (data2).
comparer = DatasetComparer(atten_failed_cases, all_failed_cases)
comparer.compare_datasets()
# Common queries for which both runs returned the same wrong prediction.
same_wrong_preds = comparer.are_wrong_predictions_same()

In [None]:
# Display the mapping: query -> (wrong prediction in data1, wrong prediction in data2).
same_wrong_preds

In [None]:
# Attention-patch failed cases whose query is not among the other run's failed cases.
different_data_atten = comparer.find_data_for_different_queries()
different_data_atten

# Visualization of the attention

In [None]:
# Matching parameters used by the attention-patch visualisation below.
threshold = 0.55
reproj_err = 3
params = [threshold, reproj_err]
method = 'AttnPatch'

query_index_offset = 0
refer_index_offset = 0

# Per-case accumulators filled by the following cells.
query_descriptors, refer_descriptors, wrong_pred_descriptors = [], [], []
query_anchors = []

query_rgbs, refer_rgbs, wrong_pred_rgbs = [], [], []

# Flat-index offsets of a 7x7 neighbourhood on a 32-wide grid:
# entry [r, c] equals 32 * (r - 3) + (c - 3).
pos_ptr = np.add.outer(32 * np.arange(-3, 4), np.arange(-3, 4))

# idx_table[row, col] is the flat index of a cell of the 32x32 grid;
# cache_table[flat_index] is the inverse lookup, giving (row, col).
idx_table = np.arange(32 * 32).reshape(32, 32)
cache_table = np.stack(np.divmod(np.arange(32 * 32), 32), axis=1).astype(int)

# Stub to warn about opencv version.
if int(cv2.__version__[0]) < 3:  # pragma: no cover
    print('Warning: OpenCV 3 is not installed')

In [None]:
print("==> Loading pre-trained network...")
fe = SuperPointFrontend(weights_path=config["model"]["weights_path"],
                        nms_dist=config["model"]["nms_dist"],
                        conf_thresh=config["model"]["conf_thresh"],
                        nn_thresh=config["model"]["nn_thresh"],
                        cuda=config["model"]["cuda"])

print("===> Successfully loaded pre-trained network.")


def _load_gray_256(image_path):
    """Read an image and return it as a 256x256 float32 grayscale array scaled to [0, 1]."""
    image = cv2.imread(image_path)
    # cv2.imread reports a missing or unreadable file by returning None rather
    # than raising, so fail here with a message that names the file.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA)
    image = image.astype('float32') / 255.
    return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)


def _select_anchors(descriptor):
    """Pick one anchor per 4x4 block of the 32x32 grid: the patch with the lowest summed self-similarity.

    The reshape to (32, 32) requires ``descriptor`` to have 1024 columns.
    """
    self_sim = np.dot(descriptor.transpose(), descriptor)
    self_sim = np.sum(self_sim, axis=0).reshape(32, 32)

    picked = []
    for row in range(8):
        for col in range(8):
            rows = slice(4 * row, 4 * (row + 1))
            cols = slice(4 * col, 4 * (col + 1))
            pos = np.argmin(self_sim[rows, cols])
            picked.append(idx_table[rows, cols].reshape(-1)[pos])
    return np.array(picked, dtype=int)


# Start from empty accumulators so that re-running this cell does not append
# a second copy of every descriptor and shift the per-case indices.
query_descriptors = []
refer_descriptors = []
wrong_pred_descriptors = []
query_anchors = []

for case in tqdm(different_data_atten):
    refer_img = _load_gray_256(REF_PATH + '/' + get_refer_img_name("SPED", case[2][0]))
    query_img = _load_gray_256(QUERY_PATH + '/' + get_query_img_name("SPED", case[0]))
    wrong_pred_img = _load_gray_256(REF_PATH + '/' + get_query_img_name("SPED", case[1]))

    refer_descriptors.append(fe.run(refer_img))

    query_desc = fe.run(query_img)
    query_descriptors.append(query_desc)

    # anchors
    # The anchors are later used to index the *query* descriptor, so they have
    # to be selected from the query descriptor (not from the reference one).
    query_anchors.append(_select_anchors(query_desc))

    wrong_pred_descriptors.append(fe.run(wrong_pred_img))

In [None]:
# Attention patch visualization
def _load_display_image(image_path):
    """Read an image as a 256x256 float32 array in [0, 1] with RGB channel order."""
    image = cv2.imread(image_path)
    # cv2.imread returns None (it does not raise) when the file cannot be read.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA)
    image = image.astype('float32') / 255.
    # cv2.imread yields BGR while matplotlib displays RGB; swap the channels.
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def _draw_patch_boxes(image, patch_ids):
    """Outline, in place, the 8x8-pixel cell of every flat 32x32-grid index in ``patch_ids``."""
    for patch_id in patch_ids:
        row, col = divmod(int(patch_id), 32)
        # The image is float in [0, 1], so the colour must be on that scale too;
        # a 0-255 value would be clipped by imshow with a warning.
        cv2.rectangle(image,
                      (col * 8, row * 8),
                      ((col + 1) * 8, (row + 1) * 8),
                      (0.0, 1.0, 0.0), 1)


def visual_atten(query_descriptor_in, refer_descriptor_in, anchors_in, idx, is_refer):
    """Match query patches against another image and draw the matched patches on both images.

    Reads the notebook globals ``threshold``, ``reproj_err``, ``pos_ptr``,
    ``cache_table``, ``different_data_atten``, ``QUERY_PATH`` and ``REF_PATH``.

    Args:
        query_descriptor_in: Query descriptor; its columns are indexed by flat
            patch index (``.transpose()[anchors_in]`` selects anchor patches).
        refer_descriptor_in: Descriptor of the image compared against, same layout.
        anchors_in: Integer array of flat patch indices used as anchors in the query.
        idx: Position of the case in ``different_data_atten`` (used to load the images).
        is_refer: True to draw on the case's reference image, False to draw on
            its wrongly predicted image.

    Returns:
        ``(query_image, other_image, score)``: two 256x256 RGB float images with
        the matched patches outlined, and the RANSAC inlier count divided by
        ``query_descriptor_in.shape[0]`` (0 when no homography was estimated).

    Raises:
        FileNotFoundError: If one of the images cannot be read.
    """
    # Best match in the other image for every anchor patch of the query.
    score_matrix = np.dot(query_descriptor_in.transpose()[anchors_in],
                          refer_descriptor_in)
    score_max_vector = np.max(score_matrix, axis=1)
    where_max_matrix = np.argmax(score_matrix, axis=1)

    confident = np.flatnonzero(score_max_vector > threshold)
    query_where = anchors_in[confident]
    refer_where = where_max_matrix[confident]

    # Expand every matched pair to its 7x7 neighbourhood (flat offsets in pos_ptr).
    query_pos = (query_where[:, np.newaxis, np.newaxis] + pos_ptr).ravel()
    refer_pos = (refer_where[:, np.newaxis, np.newaxis] + pos_ptr).ravel()

    # Keep only pairs whose two indices both fall inside the grid.
    # NOTE(review): the upper bound 1023 excludes the last valid index (1023)
    # and flat offsets can wrap across row ends; kept as-is so scores are
    # unchanged - confirm against the evaluation code.
    in_grid = ((query_pos >= 0) & (query_pos < 1023)
               & (refer_pos >= 0) & (refer_pos < 1023))
    query_roi = np.append(query_where, query_pos[in_grid])
    refer_roi = np.append(refer_where, refer_pos[in_grid])

    # Retain the candidate pairs whose descriptors are similar enough.
    mul_score = np.sum(np.multiply(query_descriptor_in.T[query_roi],
                                   refer_descriptor_in.T[refer_roi]), axis=1)
    similar = mul_score > threshold
    query_roi = query_roi[similar]
    refer_roi = refer_roi[similar]

    # One pair per distinct query patch (first occurrence wins).
    _, unique_indices = np.unique(query_roi, return_index=True)
    query_roi = query_roi[unique_indices]
    refer_roi = refer_roi[unique_indices]

    query_2d_idx = cache_table[query_roi]
    refer_2d_idx = cache_table[refer_roi]

    score = 0

    # A homography needs at least four correspondences.
    if query_2d_idx.shape[0] > 3:
        _, mask = cv2.findHomography(refer_2d_idx, query_2d_idx, cv2.RANSAC,
                                     ransacReprojThreshold=reproj_err)
        # findHomography returns no mask when it cannot estimate a model.
        if mask is not None:
            inlier_count = int(np.count_nonzero(mask.ravel() == 1))
            score = inlier_count / query_descriptor_in.shape[0]

    case = different_data_atten[idx]
    query_path = QUERY_PATH + '/' + get_query_img_name("SPED", case[0])
    if is_refer:
        refer_path = REF_PATH + '/' + get_refer_img_name("SPED", case[2][0])
    else:
        # Load the wrong prediction from the reference directory, i.e. the
        # same file its descriptors were extracted from.
        refer_path = REF_PATH + '/' + get_query_img_name("SPED", case[1])

    query_img_labels = _load_display_image(query_path)
    refer_img_labels = _load_display_image(refer_path)

    _draw_patch_boxes(query_img_labels, query_roi)
    _draw_patch_boxes(refer_img_labels, refer_roi)

    return query_img_labels, refer_img_labels, score

In [None]:
# For every case, match the query against its reference and against the wrong
# prediction, print both scores and show the four annotated images in a row.
titles = ["query", "reference", "query", "wrong prediction"]

for i in range(len(different_data_atten)):
    query_descriptor, refer_descriptor, wrong_pred_descriptor = (
        query_descriptors[i], refer_descriptors[i], wrong_pred_descriptors[i])
    anchors = query_anchors[i]

    query_img_1, refer_q_img, score_1 = visual_atten(
        query_descriptor, refer_descriptor, anchors, i, True)
    query_img_2, wrong_pred, score_2 = visual_atten(
        query_descriptor, wrong_pred_descriptor, anchors, i, False)

    print(f"Score for query is {score_1}")
    print(f"Score for wrong prediction is {score_2}")

    fig, axs = plt.subplots(1, 4, figsize=(12, 3))
    panels = (query_img_1, refer_q_img, query_img_2, wrong_pred)
    # One panel per image, titled in the same left-to-right order.
    for ax, title, panel in zip(axs, titles, panels):
        ax.imshow(panel)
        ax.set_title(title)
    plt.show()

# Filter

In [None]:
# Show Canny edge maps of the query, wrong prediction and reference image
# for every attention-patch failed case.
titles = ["query", "wrong prediction", "reference"]

for case in atten_failed_cases:
    # query, wrong prediction and reference image locations, in title order
    image_paths = (
        QUERY_PATH + '/' + get_query_img_name("SPED", case[0]),
        REF_PATH + '/' + get_query_img_name("SPED", case[1]),
        REF_PATH + '/' + get_refer_img_name("SPED", case[2][0]),
    )
    query_img, wrong_pred_img, refer_img = (cv2.imread(path) for path in image_paths)

    # edge detection with identical settings for the three images
    edges_query, edges_wrong_pred, edges_refer = (
        cv2.Canny(image, 300, 1000, apertureSize=5)
        for image in (query_img, wrong_pred_img, refer_img)
    )

    # one row of three panels per case
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    for ax, title, edges in zip(axs, titles, (edges_query, edges_wrong_pred, edges_refer)):
        ax.imshow(edges)
        ax.set_title(title)

    plt.show()