### Using the model fitted on the E_train predictions, retrieve the most similar training images for an input image

In [1]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.neighbors import NearestNeighbors
import skimage.io
from src.CV_transform_utils import resize_img, normalize_img
from src.plot_funcs import plot_query_retrieval
from pickle import dump, load

In [2]:
# Root folders for the project's data and saved models (machine-specific
# absolute paths).
datadir = '/Users/brynronalds/Insight/proj_dir/data/'
modeldir = '/Users/brynronalds/Insight/proj_dir/models/'

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load artefacts produced by this project.

# File paths of the training images.
filen = os.path.join(datadir, "cleaned/train_filenames.pkl")
with open(filen, "rb") as fh:  # context manager closes the handle promptly
    train_filenames = load(fh)

# The training images themselves.
imgfilen = os.path.join(datadir, "cleaned/train_images.pkl")
with open(imgfilen, "rb") as fh:
    train_images = load(fh)

# Query image, read in colour (as_gray=False).
img_test_path = os.path.join(datadir, 'processed/Inklusive_database/test/butterfly.jpg')
test_image = skimage.io.imread(img_test_path, as_gray=False)

# Metadata table for the training images.
dfn = os.path.join(datadir, 'processed/Inklusive_database/train/tattoo_info.csv')
tattoo_df = pd.read_csv(dfn)

In [3]:
# Build the ImageNet-pretrained VGG19 network without its classifier head
# (include_top=False) for square 3-channel inputs of side IMG_SIZE.
IMG_SIZE = 256
shape_img = (IMG_SIZE, IMG_SIZE, 3)
model = tf.keras.applications.VGG19(
    include_top=False,
    weights='imagenet',
    input_shape=shape_img,
)
# Output dimensions of the network with the leading (batch) axis dropped.
output_shape_model = tuple(map(int, model.output.shape[1:]))

In [4]:
# Resize the query image to the network input shape and normalise it, then
# add a leading batch axis so it can be passed to model.predict.
img_transformed = normalize_img(resize_img(test_image, shape_img))
X_test = np.array(img_transformed).reshape((-1, *shape_img))

In [5]:
# Run the query batch through the network, then flatten each output into a
# single feature vector per image.
E_test = model.predict(X_test)
E_test_flatten = E_test.reshape(-1, np.prod(output_shape_model))

# Drop the names that are not used after this point.
del model
del img_transformed
del X_test

In [6]:
# Load the trained knn model whose file name carries the suffix `k`
# and apply it to the test image embedding(s).
# NOTE(review): presumably `k` matches the number of neighbours the saved
# model was fitted with -- confirm against the training notebook.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load model files produced by this project.
k = '10'
filename = os.path.join(modeldir, 'finalized_knn' + k + '_model.sav')
with open(filename, 'rb') as fh:  # context manager closes the handle promptly
    knn = load(fh)

# Each pass overwrites the result lists, so after the loop they hold the
# neighbours of the last embedding in E_test_flatten.
for emb_flatten in E_test_flatten:
    _, indices = knn.kneighbors([emb_flatten])  # find k nearest train neighbours
    neighbour_ids = indices.flatten()  # flatten once, reuse for both lookups
    imgs_retrieval = [train_images[idx] for idx in neighbour_ids]
    imgs_path_retrieval = [train_filenames[idx] for idx in neighbour_ids]

In [7]:
# Look up the artist handle and studio name of every retrieved image in the
# dataframe. Rows are matched on the bare file name; when several rows match,
# the last one is used.
artist, studio = [], []
for fpath in imgs_path_retrieval:
    filen = os.path.basename(fpath)
    sl_df = tattoo_df[tattoo_df['filename'] == filen]
    artist.append(sl_df['artist handle'].iloc[-1])
    studio.append(sl_df['studio name'].iloc[-1])
    

In [8]:
# Bundle the retrieved images with their artist/studio labels and hand them
# to the plotting helper together with an output file name that records k.
outname = f"k{k}_similartats_for_butterfly.png"
result_images = (imgs_retrieval, artist, studio)
plot_query_retrieval(result_images, outname)

5
