This work is inspired by a blog post by Maciej D. Korzec: https://towardsdatascience.com/recommending-similar-images-using-pytorch-da019282770c

In [None]:
# Package installs that were needed to run this notebook on the Puhti supercomputer.
# The commands below are wrapped in a string literal, so they are NOT executed.
# If a package is missing, remove the surrounding quotes (the `!` lines are
# IPython shell syntax, so this only works inside Jupyter/IPython) or install
# the package with pip from a terminal.
'''
import sys
!{sys.executable} -m pip install torchvision
!{sys.executable} -m pip install tqdm
!{sys.executable} -m pip install numpy
!{sys.executable} -m pip install pandas
!{sys.executable} -m pip install matplotlib
'''

In [None]:
# Imports
import os
import torch
import pandas as pd
import numpy as np
import pickle
import csv
from PIL import Image
from torchvision import transforms, models
from tqdm import tqdm
from numpy.testing import assert_almost_equal

In [None]:

# Folder holding the original images to analyse.
# Adjust to match your own folder layout.
inputDir = "ill_copy_new_clip/microbio"

# Folder that receives the resized copies of the images, which are the actual
# CNN input. Adjust to match your own folder layout.
inputDirCNN = "CNN"

# Size every image is resized to before it is fed to the CNN; PyTorch's
# documentation suggests a resolution of at least 224 x 224 pixels.
inputDim = (224, 224)

In [None]:
from PIL import ImageOps, UnidentifiedImageError

os.makedirs(inputDirCNN, exist_ok = True)

# Resizes each image to inputDim (aspect ratio is not preserved).
transformationForCNNInput = transforms.Compose([transforms.Resize(inputDim)])

# Resize every image in inputDir once and store the copies in inputDirCNN.
# This takes a while for large image sets; it could be parallelised if needed.
for imageName in os.listdir(inputDir):
    sourcePath = os.path.join(inputDir, imageName)
    if not os.path.isfile(sourcePath):
        # Skip sub-directories and other non-file entries.
        continue
    try:
        with Image.open(sourcePath) as original:
            # Bake the EXIF orientation into the pixels: the resized copy is saved
            # without the original metadata, so it would otherwise appear rotated.
            upright = ImageOps.exif_transpose(original)
            resized = transformationForCNNInput(upright)
            resized.save(os.path.join(inputDirCNN, imageName))
    except UnidentifiedImageError:
        print("Skipping '{}': not a readable image".format(imageName))


In [None]:
class Img2VecResnet18():
    """Turn images into 512-dimensional ResNet-18 feature vectors.

    The vector is the output of the network's 'avgpool' layer, captured with a
    forward hook while the full pretrained model runs on the image.
    """

    def __init__(self, device="cpu"):
        """Load the pretrained ResNet-18 and the matching preprocessing.

        Args:
            device: Torch device (or device string) to run the model on.
                Defaults to "cpu"; pass "cuda" if a GPU is available.
        """
        self.device = torch.device(device)
        self.numberFeatures = 512
        self.modelName = "resnet-18"
        self.model, self.featureLayer = self.getFeatureLayer()
        self.model = self.model.to(self.device)
        # Inference mode (e.g. batch-norm uses its running statistics).
        self.model.eval()
        self.toTensor = transforms.ToTensor()

        # ImageNet per-channel mean and std, as suggested by PyTorch's
        # documentation for its pretrained models.
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def getVec(self, img):
        """Return the embedding of ``img`` as a 1-D numpy array of length ``numberFeatures``.

        Args:
            img: An RGB PIL image (anything ``self.toTensor`` accepts).
        """
        image = self.normalize(self.toTensor(img)).unsqueeze(0).to(self.device)
        embedding = torch.zeros(1, self.numberFeatures, 1, 1)

        def copyData(module, inputs, output):
            embedding.copy_(output.data)

        hook = self.featureLayer.register_forward_hook(copyData)
        try:
            # Feature extraction needs no autograd graph; skipping it saves memory and time.
            with torch.no_grad():
                self.model(image)
        finally:
            # Always detach the hook so a failed forward pass does not leave it
            # registered on the layer.
            hook.remove()

        return embedding.numpy()[0, :, 0, 0]

    def getFeatureLayer(self):
        """Build the pretrained ResNet-18 and return ``(model, avgpool_layer)``."""
        weightsEnum = getattr(models, "ResNet18_Weights", None)
        if weightsEnum is None:
            # Older torchvision without the weights enum API.
            cnnModel = models.resnet18(pretrained=True)
        else:
            # `pretrained=True` is deprecated in newer torchvision; DEFAULT selects
            # the same ImageNet weights.
            cnnModel = models.resnet18(weights=weightsEnum.DEFAULT)
        layer = cnnModel._modules.get('avgpool')
        self.layer_output_size = 512

        return cnnModel, layer
        

# Embed every resized image: allVectors maps file name -> feature vector.
img2vec = Img2VecResnet18() 

allVectors = {}
print("Converting images to feature vectors:")
# sorted() gives a reproducible row/column order in the similarity matrix.
for image in tqdm(sorted(os.listdir(inputDirCNN))):
    imagePath = os.path.join(inputDirCNN, image)
    if not os.path.isfile(imagePath):
        continue
    # The context manager closes the opened file; convert() returns a separate
    # in-memory RGB copy that needs no explicit closing.
    with Image.open(imagePath) as source:
        allVectors[image] = img2vec.getVec(source.convert("RGB"))


In [None]:
def getSimilarityMatrix(vectors):
    """Return the pairwise cosine-similarity matrix of a set of vectors.

    Args:
        vectors: Mapping from a label to a 1-D feature vector; all vectors
            must have the same length.

    Returns:
        A square ``pd.DataFrame`` whose index and columns are the mapping's
        keys (in insertion order); entry ``[a, b]`` is the cosine similarity of
        ``vectors[a]`` and ``vectors[b]``. A zero vector, for which cosine
        similarity is undefined, gets similarity 0 instead of NaN. An empty
        mapping gives an empty 0 x 0 frame.
    """
    keys = list(vectors.keys())
    if not keys:
        return pd.DataFrame(index=keys, columns=keys, dtype=float)
    matrix = np.asarray(list(vectors.values()))  # one row per vector
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    # Leave zero vectors at zero rather than dividing by zero.
    unit = matrix / np.where(norms == 0, 1, norms)
    sim = unit @ unit.T
    return pd.DataFrame(sim, columns = keys, index = keys)
        
# Pairwise cosine similarities between all embedded images, labelled by file name.
similarityMatrix = getSimilarityMatrix(allVectors)

In [None]:
# Number of top-ranked entries stored per image. An image normally ranks
# itself first (similarity ~1), so this keeps up to k - 1 real neighbours.
k = 5
# Clamp to the number of images: with fewer than k images, head(k) returns a
# shorter row, which cannot be written into a k-column table.
numStored = min(k, similarityMatrix.shape[0])

similarNames = pd.DataFrame(index = similarityMatrix.index, columns = range(numStored))
similarValues = pd.DataFrame(index = similarityMatrix.index, columns = range(numStored))

for j in tqdm(range(similarityMatrix.shape[0])):
    # Row j holds the similarities of image j to every image; keep the highest ones.
    kSimilar = similarityMatrix.iloc[j, :].sort_values(ascending = False).head(numStored)
    similarNames.iloc[j, :] = list(kSimilar.index)
    similarValues.iloc[j, :] = kSimilar.values

# Persist both tables; a later cell reloads them from these paths.
similarNames_path = "similarNames.pkl"
similarValues_path = "similarValues.pkl"
similarNames.to_pickle(similarNames_path)
similarValues.to_pickle(similarValues_path)

In [None]:
# Reload the top-k tables pickled by the previous cell. Only unpickle files
# you created yourself: unpickling untrusted data can execute arbitrary code.
# `with` closes each file even if loading fails.
with open(similarNames_path, 'rb') as file:
    simNames = pickle.load(file)

with open(similarValues_path, 'rb') as file:
    simValues = pickle.load(file)

In [None]:
def setAxes(ax, image, query = False, **kwargs):
    """Label a subplot with its image name and hide the tick marks.

    Args:
        ax: Matplotlib axes to decorate.
        image: Image file name shown in the x-axis label.
        query: True for the query image. Otherwise the label also shows the
            similarity passed as the ``value`` keyword, to three decimals.
    """
    score = kwargs.get("value", None)
    if not query:
        label = "Similarity value {1:1.3f}\n{0}".format(image, score)
    else:
        label = "Query Image\n{0}".format(image)
    ax.set_xlabel(label, fontsize = 8)
    for hideTicks in (ax.set_xticks, ax.set_yticks):
        hideTicks([])
    
def getSimilarImages(image, simNames, simVals, cutoff_value=0.93):
    """Look up the stored most-similar images of ``image``.

    Args:
        image: File name of the query image (a row label of ``simNames``).
        simNames: DataFrame of neighbour names, one row per image.
        simVals: DataFrame of the matching similarity values, with the same
            shape and labels as ``simNames``.
        cutoff_value: Similarities at or below this value are replaced by NaN.
            Positions are kept, so names and values stay aligned.

    Returns:
        ``(imgs, vals)``: neighbour names and their (possibly NaN) similarities,
        with the query image itself removed. For an unknown image a message is
        printed and ``([], [])`` is returned.
    """
    if image not in simNames.index:
        print("'{}' Unknown image".format(image))
        return [], []
    imgs = list(simNames.loc[image, :])
    rawVals = list(simVals.loc[image, :])
    # NaN never compares greater, so missing values also become NaN.
    vals = [v if v > cutoff_value else np.nan for v in rawVals]
    if image in imgs:
        # Remove the query by position so names and values stay paired, even
        # when a near-duplicate scores marginally above the self-similarity.
        selfIdx = imgs.index(image)
        # Sanity check: an image's cosine similarity with itself is 1.
        assert_almost_equal(rawVals[selfIdx], 1, decimal = 5)
        del imgs[selfIdx]
        del vals[selfIdx]
    return imgs, vals
        
def plotSimilarImages(image, similarNames, similarValues):
    """Plot the query image next to its most similar images.

    Uses the module-level ``numRow``, ``numCol``, ``inputDir`` and ``plt``.
    Slot 1 of the ``numRow`` x ``numCol`` grid shows the query; the remaining
    slots show neighbours in ranked order. Neighbours whose similarity is
    missing (NaN) or below the cutoff leave their slot empty.

    Returns:
        A list of dicts (``original_image``, ``similar_image``,
        ``similarity_score``), one per plotted neighbour.
    """
    found = getSimilarImages(image, similarNames, similarValues)
    if found is None:
        # Unknown image; getSimilarImages has already printed a message.
        return []
    simImages, simValues = found
    fig = plt.figure(figsize=(10, 20))

    # cut-off value, which determines how similar images we want
    cutoff_value = 0.93
    whole_data = []

    with Image.open(os.path.join(inputDir, image)) as img:
        ax = fig.add_subplot(numRow, numCol, 1)
        setAxes(ax, image, query = True)
        ax.imshow(img.convert('RGB'))

    for j in range(1, numCol*numRow):
        if j - 1 >= len(simImages):
            # Fewer stored neighbours than grid slots.
            break
        name, value = simImages[j-1], simValues[j-1]
        # NaN compares False with everything, so it must be checked explicitly.
        if pd.isna(value) or value < cutoff_value:
            continue
        with Image.open(os.path.join(inputDir, name)) as img:
            ax = fig.add_subplot(numRow, numCol, j+1)
            setAxes(ax, name, value = value)
            ax.imshow(img.convert('RGB'))
        whole_data.append({"original_image" : image, "similar_image": name, "similarity_score": value})

    plt.show()
    return whole_data
        

In [None]:
def write_image_data_to_csv(list_of_dicts, path='microbio_similarities.csv'):
    """Append similarity records to a CSV file.

    The header row is written only when the file does not exist yet or is
    empty, so repeated runs append rows without duplicating the header.

    Args:
        list_of_dicts: Iterable of dicts with the keys "original_image",
            "similar_image" and "similarity_score".
        path: CSV file to append to; defaults to 'microbio_similarities.csv'.
    """
    field_names = ["original_image", "similar_image", "similarity_score"]
    needsHeader = not os.path.exists(path) or os.path.getsize(path) == 0
    # newline='' is required by the csv module so row endings are not doubled on Windows.
    with open(path, 'a', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=field_names)
        if needsHeader:
            writer.writeheader()
        writer.writerows(list_of_dicts)

In [None]:
def process_data(results_list_of_dicts):
    """Placeholder for post-processing the similarity records; currently a no-op.

    Args:
        results_list_of_dicts: The similarity records collected by the main
            cell (dicts with "original_image", "similar_image",
            "similarity_score").

    Returns:
        None.
    """
    # TODO: e.g. build a table with pd.DataFrame(results_list_of_dicts) and analyse it.
    return None

In [None]:
import matplotlib.pyplot as plt
import random

# Grid used by plotSimilarImages: slot 1 is the query, the rest are neighbours.
numCol = 5
numRow = 1

# Draw up to num_files images at random from the input set and collect, for
# each one, the neighbours whose similarity passed the cutoff.
folder_path = inputDir
num_files = 1000
all_files = os.listdir(folder_path)
# Sampling without replacement is equivalent to shuffling and slicing; min()
# keeps it valid when the folder holds fewer than num_files files.
selected_files = random.sample(all_files, min(num_files, len(all_files)))

results = []

for image in selected_files:
    found = getSimilarImages(image, simNames, simValues)
    if found is None:
        # Unknown image; getSimilarImages has already printed a message.
        continue
    imgs, vals = found
    for similarName, score in zip(imgs, vals):
        # NaN marks a neighbour whose similarity did not pass the cutoff.
        if pd.isna(score):
            continue
        results.append({"original_image" : image, "similar_image" : similarName, "similarity_score" : score})

write_image_data_to_csv(results)

process_data(results)
    

In [None]:
# Display the table of most-similar image names as the notebook cell output.
similarNames