In [None]:
import numpy as np
import cv2
import yaml
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
from utils import utils

from superpoint.superpoint import SuperPointFrontend
from utils.visualization import visual_atten
from superpoint.utils import get_query_img_name, get_refer_img_name
from utils.visualization import read_failed_cases

# Read the data

In [None]:
# Load the evaluation configuration (YAML); the dataset name and both
# failed-cases locations below are derived from it.
with open("config/eval_config.yaml", "r") as fin:
    config = yaml.safe_load(fin)

DATASET = config["dataset"]["name"]


def _failed_cases_path(output_key):
    """Build the failed-cases location for one configured output.

    Args:
        output_key: Name of the config section (e.g. "output_1") that holds
            the "method" and "anchor_select_policy" entries.

    Returns:
        config["results_path"] + DATASET, followed by
        "/<method>/<anchor_select_policy>/failed_cases".  The results path
        and dataset name are concatenated without a separator, so
        config["results_path"] must already end with one.
    """
    output = config[output_key]
    return "/".join(
        [
            config["results_path"] + DATASET,
            output["method"],
            output["anchor_select_policy"],
            "failed_cases",
        ]
    )


# Failed cases of the run described by config["output_1"].
ATTEN_PATH = _failed_cases_path("output_1")
# Failed cases of the run described by config["output_2"].
FULL_GEO_PATH = _failed_cases_path("output_2")

In [None]:
# Read the failed cases stored under ATTEN_PATH (the config "output_1" run,
# labelled "atten patch" in the summary below).
atten_failed_cases = read_failed_cases(ATTEN_PATH)

# Read the failed cases stored under FULL_GEO_PATH (the config "output_2" run,
# labelled "cross match" in the summary below).
all_failed_cases = read_failed_cases(FULL_GEO_PATH)

# Report how many failed cases each run produced.
print(f"The number of failed cases in atten patch is {len(atten_failed_cases)}")
print(f"The number of failed cases in cross match is {len(all_failed_cases)}")

In [None]:
# Directories of the query and reference images: the dataset root from the
# config with the sub-directory name appended by plain string concatenation
# (no separator is inserted between the two parts).
QUERY_PATH = config["dataset"]["Root"] + config["dataset"]["query_dir"]
REF_PATH = config["dataset"]["Root"] + config["dataset"]["refer_dir"]

# Plot the failed cases (currently disabled).
# NOTE(review): display_failed_images is not imported in the notebook's import
# cell; import it before re-enabling the two calls below.
# display_failed_images(atten_failed_cases, QUERY_PATH, REF_PATH)
# display_failed_images(all_failed_cases, QUERY_PATH, REF_PATH)

## Compare the common and differing failed cases between the two runs (all vs. atten)

In [None]:
# NOTE(review): this import sits mid-notebook, while the other
# utils.visualization imports are in the first cell.
from utils.visualization import DatasetComparer

# Compare the failed cases of the attention run against those of the full run.
comparer = DatasetComparer(atten_failed_cases, all_failed_cases)
# Return value is not used here; presumably the method reports or stores its
# findings on the comparer -- TODO confirm.
comparer.compare_datasets()
same_wrong_preds = comparer.are_wrong_predictions_same()

# Bare expression so the notebook renders the value as the cell output.
same_wrong_preds

In [None]:
# Collect the data for queries that differ between the two runs (per the
# method name -- exact selection rule lives in DatasetComparer).
# Later cells index each entry as entry[0] (passed to get_query_img_name),
# entry[1] (the wrong prediction) and entry[2][0] (passed to
# get_refer_img_name).
different_data_atten = comparer.find_data_for_different_queries()
# Bare expression so the notebook renders the value as the cell output.
different_data_atten

# Visualization of the attention

In [None]:
# Accumulators filled by the descriptor-extraction loop, one entry per case.
query_descriptors = []
refer_descriptors = []
wrong_pred_descriptors = []
query_anchors = []

# Image accumulators (not filled by the cells visible in this notebook).
query_rgbs = []
refer_rgbs = []
wrong_pred_rgbs = []

# Stub to warn about opencv version.
# Parse the whole major component: taking only the first character of the
# version string would truncate a multi-digit major version.
if int(cv2.__version__.split(".")[0]) < 3:  # pragma: no cover
    print("Warning: OpenCV 3 is not installed")

In [None]:
print("==> Loading pre-trained network...")
fe = SuperPointFrontend(
    weights_path=config["model"]["weights_path"],
    nms_dist=config["model"]["nms_dist"],
    conf_thresh=config["model"]["conf_thresh"],
    nn_thresh=config["model"]["nn_thresh"],
    cuda=config["model"]["cuda"],
)

print("===> Successfully loaded pre-trained network.")


def _read_image(path):
    """Read an image with OpenCV, raising if it cannot be loaded.

    cv2.imread reports a missing or unreadable file by returning None rather
    than raising, so that case is converted into an explicit error here
    instead of surfacing later as an obscure failure inside cv2.resize.

    Raises:
        FileNotFoundError: if OpenCV could not read ``path``.
    """
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return image


def _to_network_input(image_bgr):
    """Resize to 256x256, divide by 255 as float32, then convert BGR to gray."""
    resized = cv2.resize(image_bgr, (256, 256), interpolation=cv2.INTER_AREA)
    return cv2.cvtColor(resized.astype("float32") / 255.0, cv2.COLOR_BGR2GRAY)


# Start from empty accumulators so that re-running this cell does not append
# a second copy of every entry (entry i must stay aligned with
# different_data_atten[i] for the visualisation cell).
query_descriptors = []
refer_descriptors = []
wrong_pred_descriptors = []
query_anchors = []

for i in tqdm(range(len(different_data_atten))):
    case = different_data_atten[i]

    # NOTE(review): the dataset name passed to the image-name helpers is
    # hard-coded to "SPED" instead of using DATASET from the config -- confirm.
    refer_img = _read_image(
        REF_PATH + "/" + get_refer_img_name("SPED", case[2][0])
    )
    query_ori = _read_image(
        QUERY_PATH + "/" + get_query_img_name("SPED", case[0])
    )
    # NOTE(review): the wrong prediction is read from REF_PATH but its file
    # name comes from get_query_img_name -- confirm this should not be
    # get_refer_img_name.
    wrong_pred_img = _read_image(
        REF_PATH + "/" + get_query_img_name("SPED", case[1])
    )

    if config["output_1"]["anchor_select_policy"] == "conv_filter":
        edges_query = cv2.Canny(query_ori, 300, 1000, apertureSize=5)
        # Downsample the edge map to 32 x 32.
        # NOTE(review): edges_query is not consumed by the cells visible here.
        edges_query = cv2.resize(edges_query, (32, 32), interpolation=cv2.INTER_AREA)

    refer_img = _to_network_input(refer_img)
    query_img = _to_network_input(query_ori)
    wrong_pred_img = _to_network_input(wrong_pred_img)

    keypoints, desc = fe.run_with_point(refer_img)
    refer_descriptors.append(desc)

    # Anchors: keep the first two keypoint rows (presumably the x/y
    # coordinates -- TODO confirm), map every point to its 8x8 cell, drop
    # duplicate cells and look each cell up in utils.idx_table.
    # These anchors come from the reference image even though they are stored
    # in query_anchors.
    keypoints = keypoints[:2, :].transpose()
    keypoints = [[item // 8 for item in subl] for subl in keypoints]
    keypoints = [list(t) for t in set(tuple(element) for element in keypoints)]
    anchors = np.array(
        [utils.idx_table[int(item[0]), int(item[1])] for item in keypoints]
    )
    query_anchors.append(anchors)

    desc = fe.run(query_img)
    query_descriptors.append(desc)

    desc = fe.run(wrong_pred_img)
    wrong_pred_descriptors.append(desc)

In [None]:
# For every differing case, render the attention of the query against the
# reference and against the wrong prediction, side by side.
panel_titles = ("query", "reference", "query", "wrong prediction")

for case_idx in range(len(different_data_atten)):
    # Look up everything belonging to this case before calling visual_atten.
    q_desc = query_descriptors[case_idx]
    r_desc = refer_descriptors[case_idx]
    w_desc = wrong_pred_descriptors[case_idx]
    case_anchors = query_anchors[case_idx]

    # Query vs. reference (fifth argument True).
    query_vs_refer, refer_vis, refer_score = visual_atten(
        q_desc,
        r_desc,
        case_anchors,
        case_idx,
        True,
        QUERY_PATH,
        REF_PATH,
        different_data_atten,
    )
    # Query vs. wrong prediction (fifth argument False).
    query_vs_wrong, wrong_vis, wrong_score = visual_atten(
        q_desc,
        w_desc,
        case_anchors,
        case_idx,
        False,
        QUERY_PATH,
        REF_PATH,
        different_data_atten,
    )

    print(f"Score for query is {refer_score}")
    print(f"Score for wrong prediction is {wrong_score}")

    # One row of four panels, each titled in the order given above.
    _, axes = plt.subplots(1, 4, figsize=(12, 3))
    panels = (query_vs_refer, refer_vis, query_vs_wrong, wrong_vis)
    for axis, image, title in zip(axes, panels, panel_titles):
        axis.imshow(image)
        axis.set_title(title)
    plt.show()