In [None]:
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.neighbors import NearestNeighbors
from sklearn.manifold import TSNE

import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.applications import *
import tensorflow_hub as hub

import tensorflow_datasets as tfds
# Turn off the tensorflow_datasets progress bars to keep cell output clean
tfds.disable_progress_bar()

# One seed shared by NumPy and TensorFlow so random draws repeat across runs
SEEDS = 666

tf.random.set_seed(SEEDS)
np.random.seed(SEEDS)

In [None]:
# Side length (in pixels) that every image is resized to in preprocess_test
IMAGE_SIZE = 160
# Number of images per batch produced by the tf.data pipeline
BATCH_SIZE = 64
# Let tf.data choose the parallelism / prefetch buffer size at runtime
AUTO = tf.data.AUTOTUNE

In [None]:
# Image preprocessing utils
def preprocess_test(image, label=None):
    """Resize an image to ``IMAGE_SIZE`` x ``IMAGE_SIZE`` and scale it by 1/255.

    The function can be mapped over datasets that yield bare images as well
    as datasets that yield ``(image, label)`` pairs. The latter is what the
    rest of this notebook expects, since batches are later unpacked as
    ``images, labels``. ``Dataset.map`` unpacks tuple elements into positional
    arguments, so a one-argument function would fail on such a dataset.

    Args:
        image: Image tensor accepted by ``tf.image.resize``.
        label: Optional label travelling with the image. It is returned
            untouched.

    Returns:
        The preprocessed float32 image when ``label`` is ``None``; otherwise
        the tuple ``(preprocessed_image, label)``.
    """
    image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE])
    image = tf.cast(image, tf.float32) / 255.0
    # Python-level check: evaluated once when tf.data traces the function.
    if label is None:
        return image
    return image, label

In [None]:
# NOTE: `validation_ds` must already exist (it is created outside this cell).
# The cell rebinds the same name, so running it twice stacks a second
# map/batch on top of the first -- re-create the raw dataset before re-running.
validation_ds = validation_ds.map(preprocess_test, num_parallel_calls=AUTO)
validation_ds = validation_ds.batch(BATCH_SIZE)
validation_ds = validation_ds.prefetch(buffer_size=AUTO)

In [None]:
# TF-Hub handle of the feature extractor used as the backbone
model_url = "https://tfhub.dev/google/bit/m-r50x1/1"

# Wrap the hub module as a Keras layer; its weights are kept frozen
module = hub.KerasLayer(
    handle=model_url,
    trainable=False,
)

In [None]:
class MyBiTModel(tf.keras.Model):
    """Embedding model: backbone module -> Dense(128) -> L2 normalisation.

    Args:
        module: Callable layer used as the feature extractor; ``call``
            applies it directly to the incoming image batch.
    """

    # NOTE(review): the order in which the layer attributes are created below
    # may determine how weights from an HDF5 file are matched to layers --
    # confirm before reordering or renaming them.
    def __init__(self, module):
        super().__init__()
        # Projection head: maps the backbone output to a 128-d vector
        self.dense1 = tf.keras.layers.Dense(128)
        # Rescales each row (axis=1) to unit L2 norm
        self.normalize = Lambda(lambda a: tf.math.l2_normalize(a, axis=1))
        self.bit_model = module
  
    def call(self, images):
        """Return the L2-normalised 128-d projection of the backbone output for ``images``."""
        bit_embedding = self.bit_model(images)
        dense1_representations = self.dense1(bit_embedding)
        return self.normalize(dense1_representations)

In [None]:
# Instantiate the embedding model around the hub layer loaded above
model = MyBiTModel(module=module)

In [None]:
# Build with a (batch, IMAGE_SIZE, IMAGE_SIZE, 3) input signature first;
# presumably required so the model's variables exist before weights are loaded
model.build(input_shape=(None, IMAGE_SIZE, IMAGE_SIZE, 3))
# Restore weights from a local file; "model_bit.h5" must be in the working directory
model.load_weights("model_bit.h5")

In [None]:
# Pull exactly one batch out of the validation pipeline and inspect its shapes
first_batch = validation_ds.take(1)
images, labels = next(iter(first_batch))
print(images.shape, labels.shape)

In [None]:
# Embed the sampled batch with the model
validation_features = model.predict(images)

# Only the index construction below is timed (prediction is excluded)
start = time.time()
neighbors = NearestNeighbors(n_neighbors=5, algorithm='brute', metric='euclidean')
neighbors.fit(validation_features)
elapsed = time.time() - start
print('Time taken: {:.5f} secs'.format(elapsed))

In [None]:
def plot_images(images, labels, distances):
    """Plot a query image followed by its retrieved neighbours on one figure.

    Args:
        images: Sequence of images accepted by ``plt.imshow``. Element 0 is
            treated as the query image, the rest as similar images.
        labels: Sequence aligned with ``images``; each entry is used as an
            index into the module-level ``CLASSES`` lookup.
        distances: Sequence aligned with ``images``. ``distances[i]`` is
            shown (rounded to two decimals) for every non-query image;
            entry 0 is never read.
    """
    plt.figure(figsize=(20, 10))
    columns = 4
    # Floor division: plt.subplot needs an integer row count, and "/" would
    # produce a float in Python 3.
    rows = len(images) // columns + 1
    for i, image in enumerate(images):
        ax = plt.subplot(rows, columns, i + 1)
        if i == 0:
            ax.set_title("Query Image\n" + "Label: {}".format(CLASSES[labels[i]]))
        else:
            ax.set_title("Similar Image # " + str(i) +
                         "\nDistance: " +
                         str(float("{0:.2f}".format(distances[i]))) +
                         "\nLabel: {}".format(CLASSES[labels[i]]))
        plt.imshow(image)

In [None]:
NUM_QUERIES = 6   # how many random query images to visualise
NUM_SIMILAR = 3   # neighbours shown per query; NUM_SIMILAR + 1 must not exceed n_neighbors

for _ in range(NUM_QUERIES):
    # Draw a scalar index directly. Calling int() on the 1-element array
    # returned by np.random.choice(n, 1) is deprecated in recent NumPy.
    random_index = int(np.random.choice(images.shape[0]))
    distances, indices = neighbors.kneighbors(
        [validation_features[random_index]])

    # Don't take the first closest image as it will be the same image
    neighbor_ids = indices[0][1:NUM_SIMILAR + 1]
    similar_images = [images[random_index]] + \
        [images[j] for j in neighbor_ids]
    similar_labels = [labels[random_index]] + \
        [labels[j] for j in neighbor_ids]
    # distances[0][k] lines up with position k in the lists above
    # (position 0 is the query itself and its distance is not displayed).
    plot_images(similar_images, similar_labels, distances[0])