In [1]:
import tensorflow as tf
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from enum import Enum
import os

In [2]:
def preprocess_image(img, size=(64, 64), resample=1):
    """Turn a PIL image into a normalized single-image batch for the CNN.

    Args:
        img: PIL image in any mode. It is converted to RGB first so the result
            always has 3 channels; grayscale/CMYK JPEGs would otherwise give
            arrays the model's (None, 64, 64, 3) input cannot accept.
        size: (width, height) passed to ``img.resize``. Defaults to (64, 64),
            the model's input resolution.
        resample: PIL resampling filter passed to ``img.resize``
            (1 == ``Image.LANCZOS``).

    Returns:
        ndarray of shape (1, height, width, 3) with values scaled to [0, 1].
    """
    if img.mode != 'RGB':
        img = img.convert('RGB')
    img = img.resize(size, resample=resample)
    arr = np.asarray(img)
    # Add the leading batch axis expected by Keras' predict().
    return np.expand_dims(arr, axis=0) / 255.0

def compute_distance(vec1, vec2, norm1, norm2):
    """Return the cosine similarity of two vectors, given their precomputed norms.

    Despite the name this is a *similarity*: 1 means same direction and 0
    means orthogonal. Callers therefore rank neighbours by the largest value.

    Args:
        vec1, vec2: array-likes with the same number of elements; both are
            flattened before the dot product.
        norm1, norm2: Euclidean norms of vec1 and vec2. They are passed in so
            they can be computed once and reused across many comparisons.

    Returns:
        The similarity as a float. Returns 0.0 when either norm is zero: an
        all-zero vector has no direction, and a NaN here would sort as the
        *largest* value under ``argsort`` and be picked as the best match.
    """
    denom = norm1 * norm2
    if denom == 0:
        return 0.0
    return np.ravel(vec1).dot(np.ravel(vec2)) / denom

class Channel(Enum):
    """Selects which colour channel(s) of an RGB image to keep when plotting.

    For the ``*_ONLY`` members the value is the index, on the last axis of an
    RGB array, of the channel that is kept (0 = red, 1 = green, 2 = blue).
    The other two colour channels are zeroed. ``ALL`` leaves the image as is.
    """

    RED_ONLY = 0    # keep channel 0
    GREEN_ONLY = 1  # keep channel 1
    BLUE_ONLY = 2   # keep channel 2
    ALL = 3         # no masking

In [3]:
MODEL_PATH = 'dogs_cats_cnn.h5'

# Ask TensorFlow to allocate GPU memory on demand instead of reserving
# (nearly) all of it up front. Exhausted GPU memory is a frequent cause of the
# "Failed to get convolution algorithm ... cuDNN failed to initialize" error
# raised when the embeddings are first computed. This must run before the
# GPU is initialized, i.e. before any model is loaded or built.
for gpu in tf.config.experimental.list_physical_devices('GPU'):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Raised when the GPU is already initialized (e.g. this cell is being
        # re-run); the setting can no longer be changed in this kernel.
        print('Could not enable memory growth on {}: {}'.format(gpu.name, err))

# Load the trained Keras model.
model = tf.keras.models.load_model(MODEL_PATH)

In [4]:
# Feature extractor: reuses the trained model's input but stops at the layer
# named 'dense'. Its predict() therefore returns that layer's activations
# (128 values per image, per the summary output) rather than the full
# model's output.
truncated_model = tf.keras.Model(
    inputs=model.input,
    outputs=model.get_layer(name='dense').output,
)
truncated_model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_input (InputLayer)    [(None, 64, 64, 3)]       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 62, 62, 32)        896       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 31, 31, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 29, 29, 32)        9248      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 14, 14, 32)        0         
_________________________________________________________________
flatten (Flatten)            (None, 6272)              0         
_________________________________________________________________
dense (Dense)                (None, 128)               802944

In [5]:
# Collect every JPEG path under the dataset directory, then shuffle them.
# Paths are sorted before shuffling because os.walk's order depends on the
# filesystem. Together with the fixed seed, this makes the target/query split
# used below reproducible across runs and machines.
path = '../../datasets/classification/cats and dogs/'
IMAGE_EXTENSIONS = ('.jpg', '.jpeg')  # compared case-insensitively
SEED = 0

images = sorted(
    os.path.join(root, file)
    for root, _, files in os.walk(path)
    for file in files
    # Suffix match: the old `'.jpg' in file` also accepted e.g. 'x.jpg.bak'
    # and missed upper-case '.JPG'.
    if file.lower().endswith(IMAGE_EXTENSIONS)
)
images = np.random.RandomState(SEED).permutation(images)

In [6]:
# Embed the first `num_targets` shuffled images: these are the targets whose
# nearest neighbours are searched for below. Each entry stores
# (preprocessed image without the batch axis, embedding vector, embedding norm).
# The norm is computed once here rather than for every query comparison.
num_targets = 6
targets = []
for image_path in images[:num_targets]:
    # Context manager so the file handle is released right away.
    with Image.open(image_path) as pil_img:
        batch = preprocess_image(pil_img)
    embedding = truncated_model.predict(batch).ravel()
    targets.append((batch[0], embedding, np.linalg.norm(embedding)))

UnknownError:  Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.
	 [[node model/conv2d/Conv2D (defined at <ipython-input-6-42aeeb03731a>:6) ]] [Op:__inference_distributed_function_965]

Function call stack:
distributed_function


In [None]:
# Show the target images side by side.
# squeeze=False keeps `axes` a 2-D array even when there is a single target.
fig, axes = plt.subplots(1, num_targets, figsize=(3 * num_targets, 3), squeeze=False)
for ax, (target_img, _, _) in zip(axes.ravel(), targets):
    ax.imshow(target_img)
    ax.set_xticks([])
    ax.set_yticks([])
# set_tight_layout(tight=0.1) only used the truthiness of 0.1. Calling
# tight_layout once every axes is filled gives the intended layout directly.
fig.tight_layout()
plt.show()

In [None]:
# Find the closest neighbours of each target among the remaining images,
# ranked by cosine similarity of their embeddings (higher = closer).
# Queries start right after the targets, so no target is compared to itself.
MAX_QUERIES = 5000
# Cap to the images actually available so indexing never runs past `images`.
num_queries = max(0, min(MAX_QUERIES, len(images) - num_targets))
# Holds similarities (not distances): one row per target, one column per query.
distances = np.zeros((len(targets), num_queries))
for q in range(num_queries):
    with Image.open(images[num_targets + q]) as pil_img:
        batch = preprocess_image(pil_img)
    query = truncated_model.predict(batch).ravel()
    query_norm = np.linalg.norm(query)
    for t, (_, target_vec, target_norm) in enumerate(targets):
        distances[t, q] = compute_distance(target_vec, query, target_norm, query_norm)

num_closest = 5
# Indices into `images` of the `num_closest` most similar queries per target,
# most similar first. argsort is ascending, so reverse it and take the head.
closest = distances.argsort(axis=1)[:, ::-1][:, :num_closest] + num_targets

In [None]:
# Plot each target (first column) next to its nearest neighbours. Each panel
# is titled with the full model's first output for that image (presumably the
# class score; confirm against the training code).
channel = Channel.ALL


def _keep_channel(img, which):
    """Return `img` with the colour channels not selected by `which` zeroed.

    Channel.ALL returns `img` itself. For the *_ONLY members, `which.value`
    is the index of the kept channel, and the other two of channels 0-2 are
    zeroed on a copy.
    """
    if which == Channel.ALL:
        return img
    masked = img.copy()
    for other in (0, 1, 2):
        if other != which.value:
            masked[:, :, other] = 0
    return masked


rows = len(targets)
# Derived from `closest` so a short query set cannot cause out-of-range indexing.
cols = closest.shape[1] + 1
# Height must scale with rows (it previously used cols for both dimensions).
fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)
for row in range(rows):
    for col in range(cols):
        ax = axes[row, col]
        if col == 0:
            img = targets[row][0]
            batch = np.expand_dims(img, axis=0)
        else:
            with Image.open(images[closest[row, col - 1]]) as pil_img:
                batch = preprocess_image(pil_img)
            img = batch[0]
        ax.imshow(_keep_channel(img, channel))
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title('{:.3f}'.format(model.predict(batch)[0][0]))
fig.tight_layout()
plt.show()