In [2]:
#!pip install efficientnet_pytorch

from pathlib import Path

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision
from efficientnet_pytorch import EfficientNet
from matplotlib.pyplot import imshow
from PIL import Image
from sklearn.neighbors import NearestNeighbors
from torchvision import transforms

%matplotlib inline

In [None]:
# Serialized artifacts written when the library index was built.
ball_tree_dump_file = 'library_ball_tree.sav'
lib_files_dump_file = 'library_files_list.sav'
# Placeholders: point these at the CNN weights file and the query-image folder.
CNN_MODEL_WEIGHTS = Path('/kaggle/input/...')
QUERY_DIR = Path('/kaggle/input/...')

# Run inference on the GPU when one is available.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Side length query images are resized/cropped to before entering the CNN.
# NOTE(review): defaults to EfficientNet-B1's native resolution; it must match the
# size used when the library feature vectors were extracted — TODO confirm.
RESCALE_SIZE = EfficientNet.get_image_size('efficientnet-b1')

In [None]:
# NOTE: joblib.load unpickles its input, which can execute arbitrary code —
# only load artifacts from a trusted source.
# Nearest-neighbour index over the library feature vectors.
knn_model = joblib.load(ball_tree_dump_file)

# Rebuild the EfficientNet-B1 architecture and load the saved weights.
# map_location='cpu' lets GPU-saved weights load on a CPU-only machine.
# NOTE(review): from_name builds the default classifier head; if the weights were
# saved with a different head size, load_state_dict will fail — confirm.
cnn_model = EfficientNet.from_name('efficientnet-b1')
cnn_model.load_state_dict(torch.load(CNN_MODEL_WEIGHTS, map_location='cpu'))
cnn_model.eval()

# Library file names; query_processing indexes this list with the index's
# neighbour ids, so it is presumably in the same row order as the index.
lib_files = joblib.load(lib_files_dump_file)

In [None]:
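# TODO: define prepare_model(cnn_model) in this cell — it is called before the query loop below but is not defined in any visible cell.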

In [None]:
def query_processing(knn_model, cnn_model, query_path, lib_files, n_results=10,
                     rescale_size=None):
    """Return the library files most similar to one query image.

    The image is converted to RGB, colour-adjusted, resized/cropped,
    normalized, embedded with ``cnn_model`` and looked up in ``knn_model``.

    Args:
        knn_model: fitted neighbour index exposing
            ``query(X, k=..., return_distance=False)``.
            NOTE(review): assumed to be the sklearn BallTree loaded from
            'library_ball_tree.sav' — confirm.
        cnn_model: torch module mapping a (1, 3, H, W) batch to features.
        query_path: path of the query image file.
        lib_files: sequence indexed by the neighbour ids returned by ``knn_model``.
        n_results: number of neighbours to return.
        rescale_size: resize/crop size; defaults to the global ``RESCALE_SIZE``.

    Returns:
        list of ``n_results`` entries of ``lib_files``, in the order returned
        by ``knn_model.query``.
    """
    if rescale_size is None:
        rescale_size = RESCALE_SIZE

    with Image.open(query_path) as img:
        # convert() loads the pixels into a new image, so the file can be
        # closed. It also normalizes L/P/LA/RGBA/CMYK inputs to 3 channels,
        # which the 3-channel Normalize below requires.
        image = img.convert('RGB')

    # These functional transforms return new images; the previous code
    # discarded the results, so the adjustments never took effect.
    # NOTE(review): must match the preprocessing used to build the library
    # features — TODO confirm.
    image = transforms.functional.adjust_saturation(img=image, saturation_factor=1.25)
    image = transforms.functional.adjust_gamma(img=image, gamma=0.25)

    image_transform = transforms.Compose([
        transforms.Resize(rescale_size),
        transforms.CenterCrop(rescale_size),
        transforms.ToTensor(),
        # ImageNet channel means / standard deviations.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Add the batch dimension the model expects. Place the input on the
    # model's own device; Tensor.to() is not in-place, so assign the result.
    device = next(cnn_model.parameters()).device
    batch = image_transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        # flatten(1) yields a (1, D) row whatever the head's output shape.
        feature_vector = cnn_model(batch).flatten(1).cpu().numpy()

    # query() takes a 2-D array and returns neighbour ids of shape (1, k).
    nearest_nbrs = knn_model.query(feature_vector, k=n_results, return_distance=False)
    return [lib_files[i] for i in nearest_nbrs[0]]

In [None]:
# Sorted so query order (and hence result order) is deterministic across runs.
query_files = sorted(QUERY_DIR.rglob('*.jpg'))
# NOTE(review): prepare_model is not defined in any visible cell (the
# "prepare model" cell is empty) — define it before running this cell.
prepare_model(cnn_model)

# One result list per query image, aligned with query_files.
query_results = [
    query_processing(knn_model, cnn_model, query_file, lib_files)
    for query_file in query_files
]

In [None]:
def vector_method3(model, vector_features, X_test, N_QUERY_RESULT=10):
    '''Return, for each test item, its nearest stored feature vectors (cosine distance).

    ``X_test`` is passed through ``model`` to obtain query vectors, which are
    compared with the elements of ``vector_features``.

    Args:
        model: callable mapping ``X_test`` to query vectors (one per test
            item). Torch tensors are detached and moved to CPU first.
        vector_features: non-empty sequence of stored feature vectors (or scalars).
        X_test: raw input passed to ``model``.
        N_QUERY_RESULT: neighbours per query; capped at the number of stored vectors.

    Returns:
        list with one ndarray per query, holding the matching entries of
        ``vector_features`` (as an array), in the order kneighbors returns them.

    Raises:
        ValueError: if ``vector_features`` is empty.
    '''
    new_data = model(X_test)
    if isinstance(new_data, torch.Tensor):
        # np.array() refuses tensors that require grad or live on the GPU.
        new_data = new_data.detach().cpu()
    new_data_np = np.asarray(new_data)

    vector_features_np = np.array([np.array(vector) for vector in vector_features])
    if len(vector_features_np) == 0:
        raise ValueError("vector_features is empty")

    # sklearn needs 2-D (n_samples, n_features). Scalar features become
    # 1-element rows; higher-rank features are flattened.
    stored = vector_features_np.reshape(len(vector_features_np), -1)
    # Reshape queries to the same feature width; a single 1-D query vector
    # becomes one row.
    queries = new_data_np.reshape(-1, stored.shape[1])

    n_neighbors = min(N_QUERY_RESULT, len(stored))
    nbrs = NearestNeighbors(n_neighbors=n_neighbors, metric="cosine").fit(stored)

    # One batched lookup instead of one kneighbors() call per query row.
    indices = nbrs.kneighbors(queries, return_distance=False)
    return [vector_features_np[row] for row in indices]

In [None]:
# One row per query: the query path followed by its ranked matches.
query_dataframe = pd.DataFrame({'query': [str(path) for path in query_files]})
# Prefix the match columns so they cannot collide with the query column.
results_dataframe = pd.DataFrame(query_results).add_prefix('match_')
# axis=1 pairs each query with its results side by side. The default axis=0
# stacks the two frames vertically, so queries and results never line up.
result_df = pd.concat([query_dataframe, results_dataframe], axis=1)
# CSV text of the table (to_csv without a path returns a string).
result_csv = result_df.to_csv(index=False)
result_df.head()

In [None]:
def random_query_show(query_files, query_results, rng=None):
    """Plot a randomly chosen query image next to its retrieved matches.

    Args:
        query_files: sequence of query image paths.
        query_results: per-query lists of matched image paths, aligned with
            ``query_files``.
        rng: optional ``numpy.random.Generator``. When omitted, numpy's global
            random state is used, so ``np.random.seed`` still applies.
    """
    if len(query_results) == 0:
        print("No query results to show.")
        return
    if rng is None:
        rand_idx = np.random.randint(len(query_results))
    else:
        rand_idx = int(rng.integers(len(query_results)))

    # The previous code reopened the query image for every "result" and drew
    # everything on one axes, so only a single image was ever visible.
    paths = [query_files[rand_idx], *query_results[rand_idx]]
    titles = ["Query image"] + [f"Result {k}" for k in range(1, len(paths))]

    fig, axes = plt.subplots(1, len(paths), figsize=(3 * len(paths), 3), squeeze=False)
    for ax, path, title in zip(axes[0], paths, titles):
        with Image.open(path) as pil_im:
            # cmap only affects single-channel images; RGB ignores it.
            ax.imshow(np.asarray(pil_im), cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    plt.show()

In [None]:
# Pick one query at random and plot it together with its retrieved library images.
random_query_show(query_files=query_files, query_results=query_results)