In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import multiprocessing as mp
from tensorflow.keras.metrics import CosineSimilarity
from glob import glob
from tqdm import tqdm

from siamese_network.data_pipeline import input_dataset
from siamese_network.model import embedding
from siamese_network.model import siamese_model

In [None]:
# Get the lists of paths to anchor and positive images.
# glob() returns files in arbitrary, filesystem-dependent order, but later cells
# pair anchor_images[i] with positive_images[i], so sort both lists.
# NOTE(review): index pairing after sorting assumes left/ and right/ use
# matching file names — confirm against the data layout.
anchor_images = sorted(glob("../data/left/*"))
positive_images = sorted(glob("../data/right/*"))
len(anchor_images), len(positive_images)

# Check out model inference

In [None]:
# Build the triplet datasets, returned as (train, validation), in batches of 16.
batch_size = 16
ds_train, ds_validation = input_dataset.create_triplet_dataset(
    anchor_images,
    positive_images,
    batch_size=batch_size,
)

In [None]:
# Pull one batch from the training set to drive the inference checks.
# Fail loudly on an empty dataset instead of leaving `x` undefined.
x = next(iter(ds_train.take(1)), None)
if x is None:
    raise ValueError("ds_train yielded no batches; check the image paths above.")

In [None]:
import os

# Build the Siamese model; it is used here for inference only.
model = siamese_model.SiameseModel()

# Run one forward pass so the model's variables exist before loading weights
# (presumably a subclassed Keras model that builds lazily — TODO confirm).
model(x);

# Load the most recently written checkpoint. glob() order is arbitrary, so
# pick by modification time instead of taking the last glob() entry.
checkpoint_paths = glob('../data/results/checkpoints/*.hdf5')
if not checkpoint_paths:
    raise FileNotFoundError("No checkpoints found in '../data/results/checkpoints/'")
model.load_weights(max(checkpoint_paths, key=os.path.getmtime))

In [None]:
# Embed each part of the sample batch: x[0] = anchors, x[1] = positives,
# x[2] = negatives.
# NOTE: `embeddings_anchor` is rebound later in the notebook to the
# embeddings of *all* anchor images.
embeddings_anchor, embeddings_positive, embeddings_negative = (
    model.embedding(x[k]) for k in range(3)
)

In [None]:
# Keras cosine-similarity metric instance.
# NOTE(review): not used by any active code in this notebook; the distance
# comparisons use siamese_model.l2_distance instead.
cs = CosineSimilarity()

In [None]:
n_triplets = 16
# Never ask for more rows than the sample batch actually holds.
n_triplets = min(n_triplets, len(x[0]))

# Embed the sample batch here rather than reusing module-level embedding
# variables: `embeddings_anchor` is later rebound to the embeddings of *all*
# anchor images, which would silently mis-pair rows if this cell were re-run.
emb_anchor, emb_pos, emb_neg = (model.embedding(x[k]) for k in range(3))

# squeeze=False keeps `axes` 2-D even when n_triplets == 1.
fig, axes = plt.subplots(n_triplets, 3, figsize=(12, 4.4 * n_triplets), squeeze=False)

for i in range(n_triplets):
    ax_anchor, ax_pos, ax_neg = axes[i]
    # L2 distances: lower means the pair is closer in embedding space.
    ap_distance = siamese_model.l2_distance(emb_anchor[i], emb_pos[i]).numpy().item()
    an_distance = siamese_model.l2_distance(emb_anchor[i], emb_neg[i]).numpy().item()
    ax_anchor.imshow(x[0][i])
    ax_pos.imshow(x[1][i])
    ax_pos.set_title("d(anchor, pos) = {:.3f}".format(ap_distance))
    ax_neg.imshow(x[2][i])
    ax_neg.set_title("d(anchor, neg) = {:.3f}".format(an_distance))
axes[0, 0].set_title("Anchor", fontsize=16)
# Lay out after drawing so titles are accounted for.
fig.tight_layout(w_pad=1);

# Check how often the correct image is chosen

Here, "correct" means that the anchor–positive distance is smaller than the anchor–negative distance.

In [None]:
# Run the full model over every validation batch. Each output is indexed
# below as (anchor-positive, anchor-negative) distances — presumably what
# SiameseModel.__call__ returns; TODO confirm.
distances = [model(batch) for batch in tqdm(ds_validation)]

In [None]:
# Stitch the per-batch outputs together: index 0 holds anchor-positive
# distances, index 1 holds anchor-negative distances.
ap_distances, an_distances = (
    tf.concat([batch_out[k] for batch_out in distances], axis=0).numpy()
    for k in (0, 1)
)

# A triplet counts as correct when the positive is closer than the negative.
correct_pct = (ap_distances < an_distances).mean() * 100
print("Proportion of triplets where correct image chosen: {:.2f}%".format(correct_pct))

# Create embedding matrix from anchors

In [None]:
# Dataset of all anchor images with preprocessing applied, batched for
# embedding prediction.
ds_anchor = input_dataset.create_images_dataset(
    anchor_images, map_preprocessing_fnc=True
).batch(32)

In [None]:
embeddings_anchor_path = 'embeddings_anchor.npy'

# Cache the (expensive) anchor embeddings on disk.
# NOTE: the cache is not tied to the loaded checkpoint; delete the file after
# retraining, otherwise embeddings from the old weights are reused.
try:
    # The file holds a plain numeric array, so pickle support is unnecessary;
    # allow_pickle=False refuses object arrays (unpickling can execute code).
    cached_embeddings = np.load(embeddings_anchor_path, allow_pickle=False)
except FileNotFoundError:
    cached_embeddings = None

if cached_embeddings is not None and len(cached_embeddings) == len(anchor_images):
    embeddings_anchor = cached_embeddings
    print("Loaded embeddings from '{}'".format(embeddings_anchor_path))
else:
    # Either there is no cache yet, or it was built for a different anchor set.
    print("Producing embeddings...")
    embeddings_anchor = model.embedding.predict(ds_anchor, verbose=1)
    np.save(embeddings_anchor_path, embeddings_anchor, allow_pickle=False)

embeddings_anchor.shape

In [None]:
def calculate_l2_distances_to_one_image(query_image_path):
    """Return the L2 distance from one query image to every anchor embedding.

    The query image is loaded and preprocessed, embedded with
    ``model.embedding``, and compared against each row of the module-level
    ``embeddings_anchor`` via ``siamese_model.l2_distance``.

    Returns:
        NumPy array with one distance per row of ``embeddings_anchor``,
        in the same row order.
    """
    query_batch = tf.expand_dims(
        input_dataset.load_and_preprocess_image(query_image_path), axis=0
    )
    query_embedding = model.embedding(query_batch)
    # NOTE(review): one l2_distance call per anchor; if l2_distance broadcasts
    # over a leading batch axis, a single vectorized call would be much faster
    # for sweeps over every query image.
    per_anchor = [
        siamese_model.l2_distance(anchor_vec, query_embedding)
        for anchor_vec in embeddings_anchor
    ]
    return tf.concat(per_anchor, axis=0).numpy()


def find_top_similar_images(query_image_path, top_k=3):
    """Find the ``top_k`` anchor images closest to the query image.

    Args:
        query_image_path: Path of the query image.
        top_k: Number of neighbours to return; clamped to the number of anchors.

    Returns:
        ``(top_paths, top_distances)``: anchor image paths and their L2
        distances to the query, ordered from nearest to farthest.
    """
    distances = calculate_l2_distances_to_one_image(query_image_path)
    top_k = min(top_k, len(distances))
    if top_k <= 0:
        return [], distances[:0]
    # argpartition requires kth < len(distances). kth = top_k - 1 satisfies that
    # even when top_k == len(distances), and still places the top_k smallest
    # values (in arbitrary order) in the first top_k slots.
    top_indices = np.argpartition(distances, top_k - 1)[:top_k]
    # Order the selected neighbours nearest first.
    top_indices = top_indices[np.argsort(distances[top_indices], kind="stable")]
    top_distances = distances[top_indices]
    top_paths = [anchor_images[i] for i in top_indices]
    return top_paths, top_distances


def visualize_top_similarity_images(query_image_path, query_anchor_path, most_similar_images_paths, most_similar_images_distances):
    """Plot a query image, its true anchor, and the retrieved nearest anchors.

    Row 0 shows the query image and its paired anchor; the anchor is titled
    with the L2 distance between the two embeddings. The following rows show
    the retrieved anchor images, four per row, sorted nearest first and
    titled with their distances.

    Args:
        query_image_path: Path of the query (positive) image.
        query_anchor_path: Path of the anchor image paired with the query.
        most_similar_images_paths: Paths of the retrieved anchor images.
        most_similar_images_distances: Distances matching
            ``most_similar_images_paths`` element-wise.
    """
    # The bare name `load_and_preprocess_image` is not imported in this
    # notebook; it lives in input_dataset.
    load = input_dataset.load_and_preprocess_image

    n_cols = 4
    top_k = len(most_similar_images_paths)
    n_rows = int(np.ceil(top_k / n_cols)) + 1
    # squeeze=False keeps `axes` 2-D even when there is only one row (top_k == 0).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, n_rows * 4.4), squeeze=False)

    # Load each of the query/anchor images once and reuse them for display and embedding.
    query_img = load(query_image_path)
    anchor_img = load(query_anchor_path)
    axes[0, 0].imshow(query_img)
    axes[0, 0].set_title("Query")
    axes[0, 1].imshow(anchor_img)
    true_distance = siamese_model.l2_distance(
        model.embedding(tf.expand_dims(query_img, axis=0)),
        model.embedding(tf.expand_dims(anchor_img, axis=0)),
    ).numpy()[0]
    axes[0, 1].set_title("True anchor: {:.3f}".format(true_distance))
    for ax in axes[0, 2:]:
        ax.set_visible(False)

    # Show the retrieved anchors nearest first.
    order = np.argsort(np.asarray(most_similar_images_distances), kind="stable")
    result_axes = axes[1:].ravel()
    for ax, idx in zip(result_axes, order):
        ax.imshow(load(most_similar_images_paths[idx]))
        ax.set_title("{:.3f}".format(most_similar_images_distances[idx]))
    # Hide the unused cells in the last row.
    for ax in result_axes[top_k:]:
        ax.set_visible(False)

    # Lay out after drawing so titles are accounted for.
    fig.tight_layout(w_pad=1)

In [None]:
top_k = 8
# Seeded RNG so the randomly chosen query is reproducible across runs;
# change SEED to inspect a different example.
SEED = 0
rng = np.random.default_rng(SEED)
i = int(rng.integers(len(positive_images)))
q_image_path = positive_images[i]
# Anchor paired with the query by index.
q_anchor_path = anchor_images[i]
closest_image_paths, closest_distances = find_top_similar_images(q_image_path, top_k=top_k)

In [None]:
# Show the query, its paired anchor and the retrieved nearest anchors.
visualize_top_similarity_images(
    query_image_path=q_image_path,
    query_anchor_path=q_anchor_path,
    most_similar_images_paths=closest_image_paths,
    most_similar_images_distances=closest_distances,
)

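# Find proximity of two images

For each positive image, compute the rank of its paired anchor among all anchors: the number of anchors strictly closer to the query than the paired one (0 means the paired anchor is the nearest).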

In [None]:
def find_rank(idx, query_image_path):
    """Count the anchors that are strictly closer to the query than anchor ``idx``.

    A result of 0 means anchor ``idx`` is the nearest anchor to the query.
    Anchors tied with anchor ``idx`` do not increase the rank.
    """
    all_distances = calculate_l2_distances_to_one_image(query_image_path)
    reference = all_distances[idx]
    return np.sum(all_distances < reference)

In [None]:
# Rank of each positive image's paired anchor (same index) among all anchors.
ranks = [
    find_rank(i, q_im_path)
    for i, q_im_path in tqdm(enumerate(positive_images), total=len(positive_images))
]