ADAPTED: Bryn Ronalds, Insight 20B DS TO
"""

 image_retrieval.py  (author: Anson Wong / git: ankonzoid)

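 We perform image retrieval using transfer learning on a pre-trained
 VGG19 image classifier (ImageNet weights, classification head removed).
 For each query image we plot the k=5 most similar training images.
 The t-SNE visualization of the training embeddings is currently
 disabled (commented out in the last cell).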

"""

In [1]:
import os
import numpy as np
import tensorflow as tf
from sklearn.neighbors import NearestNeighbors
import skimage.io
from skimage.transform import resize
from multiprocessing import Pool
from CV_transform_utils import apply_transformer, resize_img, normalize_img
from CV_plot_utils import plot_query_retrieval
from pickle import dump


Bad key "text.kerning_factor" on line 4 in
/Users/brynronalds/anaconda3/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test_patch.mplstyle.
You probably need to get an updated matplotlibrc file from
https://github.com/matplotlib/matplotlib/blob/v3.1.3/matplotlibrc.template
or from the matplotlib source distribution


In [2]:
# Run configuration (transfer learning with a VGG19 backbone)
parallel = True      # use multicore processing for image loading / transforms
modelName = "vgg19"  # backbone name; also used as the prefix of output plot files
trainModel = True    # NOTE(review): not read by any cell shown in this notebook

In [3]:
# Make paths
# The project root can be overridden with the PROJDIR environment variable so the
# notebook runs outside the original author's machine; the default is unchanged.
PROJDIR = os.environ.get("PROJDIR", '/Users/brynronalds/Insight/proj_dir/')
DATABASE_DIR = os.path.join(PROJDIR, 'data', 'processed', 'Inklusive_database')
TATDIR = os.path.join(DATABASE_DIR, 'train')
TESTDIR = os.path.join(DATABASE_DIR, 'test')
OUTDIR = os.path.join(DATABASE_DIR, 'output')

dataTrainDir = TATDIR  # images searched during retrieval
dataTestDir = TESTDIR  # query images
outDir = OUTDIR        # destination for retrieval plots

In [4]:
# Read images
def read_img(filePath):
    """Load a single image file from disk without converting it to grayscale."""
    return skimage.io.imread(filePath, as_gray=False)

def read_imgs_dir(dirPath, extensions, parallel=True):
    """Read every image in ``dirPath`` whose filename ends with one of ``extensions``.

    Args:
        dirPath: directory to scan (non-recursive).
        extensions: iterable of filename suffixes, e.g. [".jpg", ".jpeg"];
            matching is case-insensitive on both the filename and the suffix.
        parallel: if True, decode the files with a multiprocessing Pool.

    Returns:
        (imgs, paths): decoded images and their full file paths, index-aligned,
        in sorted filename order so indices are reproducible across runs.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    paths = [os.path.join(dirPath, filename)
             for filename in sorted(os.listdir(dirPath))  # listdir order is arbitrary
             if filename.lower().endswith(suffixes)]
    if parallel and paths:
        # The context manager shuts the worker processes down even if decoding
        # a file raises; map() has already collected every result on success.
        with Pool() as pool:
            imgs = pool.map(read_img, paths)
    else:
        imgs = [read_img(path) for path in paths]
    return imgs, paths

# Load the train (retrieval pool) and test (query) images along with their paths.
extensions = [".jpg", ".jpeg"]
imgs_train, filenames_train = read_imgs_dir(dataTrainDir, extensions=extensions, parallel=parallel)
imgs_test, filenames_test = read_imgs_dir(dataTestDir, extensions=extensions, parallel=parallel)

In [5]:
# Load pre-trained VGG19 (ImageNet weights) as a convolutional feature extractor;
# include_top=False drops the fully-connected classification head.
IMG_SIZE = 256
shape_img = (IMG_SIZE,) * 2 + (3,)
model = tf.keras.applications.VGG19(
    weights='imagenet',
    include_top=False,
    input_shape=shape_img,
)

# Per-sample shapes with the batch axis dropped.
input_shape_model = tuple(int(dim) for dim in model.input.shape[1:])
output_shape_model = tuple(int(dim) for dim in model.output.shape[1:])
# Images are resized to exactly the model's input shape.
shape_img_resize = input_shape_model
n_epochs = None  # NOTE(review): not read by any cell shown here

# Print some model info
print(f"input_shape_model = {input_shape_model}")
print(f"output_shape_model = {output_shape_model}")

input_shape_model = (256, 256, 3)
output_shape_model = (8, 8, 512)


In [6]:
# Apply transformations to all images
class ImageTransformer(object):
    """Callable preprocessing step: resize an image, then normalize it.

    Args:
        shape_resize: target shape handed to ``resize_img``.
    """

    def __init__(self, shape_resize):
        self.shape_resize = shape_resize

    def __call__(self, img):
        """Return ``normalize_img`` applied to ``img`` resized to ``self.shape_resize``."""
        resized = resize_img(img, self.shape_resize)
        return normalize_img(resized)

transformer = ImageTransformer(shape_img_resize)


def _transform_split(split_name, imgs):
    """Announce and apply the shared transformer to one split of images."""
    print("Applying image transformer to {} images...".format(split_name))
    return apply_transformer(imgs, transformer, parallel=parallel)


imgs_train_transformed = _transform_split("training", imgs_train)
imgs_test_transformed = _transform_split("test", imgs_test)

Applying image transformer to training images...
Applying image transformer to test images...


In [7]:
# Stack the transformed images into batches shaped (n_images,) + input_shape_model.
X_train, X_test = [np.array(imgs).reshape((-1,) + input_shape_model)
                   for imgs in (imgs_train_transformed, imgs_test_transformed)]
print(f" -> X_train.shape = {X_train.shape}")
print(f" -> X_test.shape = {X_test.shape}")

 -> X_train.shape = (6703, 256, 256, 3)
 -> X_test.shape = (6, 256, 256, 3)


In [8]:
# Create embeddings: run the feature extractor on each split, then flatten every
# per-image feature map (output_shape_model) into one vector for the kNN index.
print("Inferencing embeddings using pre-trained model...")
E_train = model.predict(X_train)
E_test = model.predict(X_test)
E_train_flatten, E_test_flatten = (
    emb.reshape((-1, np.prod(output_shape_model))) for emb in (E_train, E_test)
)
print(f" -> E_train.shape = {E_train.shape}")
print(f" -> E_test.shape = {E_test.shape}")
print(f" -> E_train_flatten.shape = {E_train_flatten.shape}")
print(f" -> E_test_flatten.shape = {E_test_flatten.shape}")

Inferencing embeddings using pre-trained model...
 -> E_train.shape = (6703, 8, 8, 512)
 -> E_test.shape = (6, 8, 8, 512)
 -> E_train_flatten.shape = (6703, 32768)
 -> E_test_flatten.shape = (6, 32768)


In [9]:
# Cache the embeddings to disk so later sessions can skip model inference.
# The E_test*.pkl files must hold the *test* embeddings, and each file handle
# is closed as soon as its pickle has been written.
embedding_files = {
    "E_train.pkl": E_train,
    "E_train_flatten.pkl": E_train_flatten,
    "E_test.pkl": E_test,
    "E_test_flatten.pkl": E_test_flatten,
}
for pkl_path, embedding in embedding_files.items():
    with open(pkl_path, "wb") as pkl_file:
        dump(embedding, pkl_file)

In [10]:
# Fit kNN model on training images: cosine distance over the flattened
# embeddings; kneighbors() will return the 5 closest training images.
print("Fitting k-nearest-neighbour model on training images...")
knn = NearestNeighbors(metric="cosine", n_neighbors=5).fit(E_train_flatten)
knn  # display the fitted estimator

Fitting k-nearest-neighbour model on training images...


NearestNeighbors(metric='cosine')

In [12]:
# Persist the fitted kNN index so retrieval can run later without refitting.
# NOTE: unpickling can execute arbitrary code -- only load this file from a trusted source.
filename = 'finalized_model.sav'
with open(filename, 'wb') as model_file:
    dump(knn, model_file)

In [13]:
# Perform image retrieval on test images
print("Performing image retrieval on test images...")
# The plot paths below point into outDir; make sure it exists so a fresh
# checkout does not fail when the first figure is written.
os.makedirs(outDir, exist_ok=True)
for i, emb_flatten in enumerate(E_test_flatten):
    _, indices = knn.kneighbors([emb_flatten])  # find k nearest train neighbours
    neighbour_ids = indices.flatten()  # one query row -> 1-D index array
    img_query = imgs_test[i]  # query image
    imgs_retrieval = [imgs_train[idx] for idx in neighbour_ids]  # retrieval images
    imgs_path_retrieval = [filenames_train[idx] for idx in neighbour_ids]  # their filenames
    outFile = os.path.join(outDir, "{}_retrieval_{}.png".format(modelName, i))
    plot_query_retrieval(img_query, imgs_retrieval, imgs_path_retrieval, outFile)

Performing image retrieval on test images...


In [None]:
# # Plot t-SNE visualization
# print("Visualizing t-SNE on training images...")
# outFile = os.path.join(outDir, "{}_tsne.png".format(modelName))
# plot_tsne(E_train_flatten, imgs_train, outFile)