# 0 - Install required packages

In [None]:
# Required packages: image, google_images_download, faiss

# 1 - Download the datasets

First we will get a table of the year, make and model of cars.

In [None]:
%%bash
# Start from a clean slate: remove any previous dataset directory and recreate it.
rm -Rf dataset
mkdir -pv dataset
# Clone over HTTPS so that no GitHub SSH key is required to fetch the repository.
git clone https://github.com/arthurkao/vehicle-make-model-data.git dataset/

In [None]:
import pandas as pd

# Load the car table from the CSV file inside the `dataset/` directory.
car_make_data = pd.read_csv('dataset/csv_data.csv')

# Preview the first rows (shown as the cell output when run in a notebook).
car_make_data.head()

In [None]:
# we build the queries we will use to search google images
def cols_to_query(row):
    """Build a lower-cased search query from one row of the car table.

    Args:
        row: Mapping-like object (for example a DataFrame row) providing
            ``'year'``, ``'make'`` and ``'model'`` entries.

    Returns:
        The three values joined by single spaces, in lower case.
    """
    # Convert every field to ``str``: a non-string value (a numeric model
    # name, for example) would otherwise make ``str.join`` raise TypeError.
    fields = (row['year'], row['make'], row['model'])
    return ' '.join(str(field) for field in fields).lower()
                     
# Store one search string per row in a new `queries` column
# (`axis=1` makes `apply` call `cols_to_query` row by row).
car_make_data['queries'] = car_make_data.apply(cols_to_query, axis=1)

In [None]:
# now we retrieve images from google images using the 
# `google_images_download` python package
from time import time
from google_images_download import google_images_download  

client = google_images_download.googleimagesdownload()

n_samples = 10
n_images_per_sample = 100

# Fixed seed so that the same queries are sampled on every run.
sample_queries = car_make_data['queries'].sample(n_samples, random_state=42)

# `Series.items()` is used because `iteritems()` was removed in pandas 2.0.
# Counting from 1 makes the progress read 1/10 ... 10/10 instead of 0/10 ... 9/10.
for i, (_, query) in enumerate(sample_queries.items(), start=1):
    arguments = {
        "keywords": query,
        "limit": n_images_per_sample,
        "size": "medium",
        "format": "png",
        "output_directory": 'dataset/'
    }
    start_time = time()
    # NOTE(review): per the google_images_download documentation, `download`
    # returns the saved paths grouped by keyword plus an error count, so the
    # number of images is derived from the paths rather than from the second
    # value. Confirm against the installed version.
    paths, n_errors = client.download(arguments)
    n_downloaded = sum(len(keyword_paths) for keyword_paths in paths.values())

    print('Downloaded {} images for `{}` ({}/{}) in {:.2f}s'.format(
        n_downloaded, query, i, n_samples, time() - start_time))

In [None]:
import os

# Map every class folder found directly under `dataset/` to the paths of the
# files it contains (presumably one folder per downloaded search query).
car_class_to_paths = {}
for folder in sorted(os.listdir('dataset/')):
    folder_path = os.path.join('dataset', folder)
    # Skip plain files and hidden directories: `dataset/` is also a git
    # checkout, and its `.git` directory is not a car class.
    if folder.startswith('.') or not os.path.isdir(folder_path):
        continue
    print(folder)
    candidate_paths = (
        os.path.join(folder_path, name)
        for name in sorted(os.listdir(folder_path))
    )
    # Keep regular files only; nested directories cannot be loaded as images.
    car_class_to_paths[folder] = [
        candidate for candidate in candidate_paths
        if os.path.isfile(candidate)
    ]

car_class_to_paths

# 2 - Generate image representations

In [None]:
from keras.applications.vgg16 import VGG16
from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input
import numpy as np

# VGG16 with ImageNet weights and without its classification head
# (`include_top=False`). `pooling='avg'` is requested so that, per the Keras
# documentation, each image is reduced to a single pooled feature vector.
model = VGG16(weights='imagenet', include_top=False, pooling='avg')

In [None]:
# Total number of candidate files across all classes.
n_images = len([v for vs in car_class_to_paths.values() for v in vs])
features = []
labels = []
paths = []
i = 0  # number of images successfully embedded so far
max_samples = 500
for (label, label_paths) in car_class_to_paths.items():
    # Stop visiting further classes once the cap is reached; breaking out of
    # the inner loop alone would let every remaining class add more images.
    if i >= max_samples:
        break
    print(label)
    for label_path in label_paths:
        if i >= max_samples:
            break
        print(label_path)
        try:
            img = image.load_img(label_path, target_size=(224, 224))
        except Exception as exc:
            # Best effort: unreadable or non-image files are reported and
            # skipped. `Exception` (not a bare `except`) keeps
            # KeyboardInterrupt able to stop the loop.
            print('Skipping {}: {}'.format(label_path, exc))
            continue
        x = image.img_to_array(img)
        x = np.expand_dims(x, axis=0)  # add the batch dimension
        x = preprocess_input(x)
        features.append(model.predict(x))
        labels.append(label)
        paths.append(label_path)
        i += 1

# Stack the per-image predictions along the batch axis.
features_arr = np.concatenate(features, axis=0)
print('Representations computed for {} images'.format(i))

# 3 - Index into a Nearest Neighbor structure (exact L2 search with faiss)

In [None]:
import faiss  
import tempfile
import urllib

import matplotlib.pyplot as plt
import urllib.request


def display_from_url(url, ax=None):
    """Fetch the image at *url* and draw it on a matplotlib axes.

    Args:
        url: Address of the image to download.
        ax: Axes to draw on. When omitted, the current axes at call time
            (``plt.gca()``) is used.

    Returns:
        Whatever ``ax.imshow`` returns for the drawn image.
    """
    # Resolve the default lazily: a `plt.gca()` default in the signature is
    # evaluated once, when the function is defined, so every later call
    # without `ax` would draw on that first axes.
    if ax is None:
        ax = plt.gca()
    with urllib.request.urlopen(url) as response:
        # The second positional argument of `imread` is kept as in the
        # original call.
        img = plt.imread(response, 0)
    return ax.imshow(img)


def display_from_path(path, ax=None):
    """Read the image file at *path* and draw it on a matplotlib axes.

    Args:
        path: Location of the image file on disk.
        ax: Axes to draw on. When omitted, the current axes at call time
            (``plt.gca()``) is used.

    Returns:
        Whatever ``ax.imshow`` returns for the drawn image.
    """
    # Resolve the default lazily: a `plt.gca()` default in the signature is
    # evaluated once, when the function is defined, so every later call
    # without `ax` would draw on that first axes.
    if ax is None:
        ax = plt.gca()
    with open(path, 'rb') as f:
        # The second positional argument of `imread` is kept as in the
        # original call.
        img = plt.imread(f, 0)
    return ax.imshow(img)

In [None]:
class NearestNeighborsIndex:
    """Nearest-neighbour lookup over image representations, backed by faiss.

    The index stores one vector per image and remembers the file path each
    vector came from, so that search results can be reported as paths.
    """

    def __init__(self, representations, paths, model, model_input_size=(224, 224)):
        """Index *representations* and keep the matching *paths*.

        Args:
            representations: 2-D array with one row per image.
            paths: Sequence whose i-th entry is the file behind row i.
            model: Object whose ``predict`` turns a preprocessed image batch
                into a representation compatible with the indexed rows.
            model_input_size: Size images are resized to before prediction.
        """
        self.index = faiss.IndexFlatL2(representations.shape[1])
        # The index is fed float32, C-contiguous rows.
        self.index.add(np.ascontiguousarray(representations, dtype=np.float32))
        print("{}/{} documents indexed".format(self.index.ntotal,
                                               representations.shape[0]))
        self.paths = paths
        self.model = model
        self.model_input_size = model_input_size

    def _preprocess_and_predict(self, img):
        """Turn a loaded image into the model's representation of it."""
        x = image.img_to_array(img)
        x = np.expand_dims(x, axis=0)  # add the batch dimension
        x = preprocess_input(x)
        return self.model.predict(x)

    def search(self, x, k=5):
        """Return distances and neighbour paths for the query batch *x*.

        Args:
            x: 2-D array of query representations.
            k: Number of neighbours requested.

        Returns:
            ``(d, paths)`` where ``d`` is the distance array exactly as
            returned by the index and ``paths`` lists the files of the
            neighbours of the FIRST query row only. ``paths`` is shorter than
            ``k`` when the index holds fewer than ``k`` vectors.
        """
        query = np.ascontiguousarray(x, dtype=np.float32)
        d, ixs = self.index.search(query, k)
        # faiss reports missing neighbours as -1; without the filter Python's
        # negative indexing would silently return the last stored path.
        paths = [self.paths[ix] for ix in ixs[0] if ix >= 0]
        return d, paths

    def search_from_path(self, path, k=5):
        """Search the neighbours of the image stored at *path*."""
        img = image.load_img(path, target_size=self.model_input_size)
        x = self._preprocess_and_predict(img)
        return self.search(x, k)

    def search_from_url(self, url, k=5):
        """Download the image at *url* and search its neighbours."""
        with tempfile.NamedTemporaryFile() as f:
            # The download lives in a temporary file that is removed when the
            # `with` block exits, so the search runs inside the block.
            urllib.request.urlretrieve(url, f.name)
            return self.search_from_path(f.name, k)


In [None]:
# Index the computed features; `paths` is passed alongside so that search
# hits can be reported as file paths.
nn = NearestNeighborsIndex(features_arr, paths, model)

In [None]:
def search_and_plot_from_url(url, k=5):
    """Plot the image at *url* beside its *k* nearest indexed neighbours.

    The query image goes in the first subplot and each neighbour in one of
    the following subplots; axes decorations are hidden on every subplot
    that receives an image.

    Returns:
        The file paths of the neighbours found by the module-level ``nn``.
    """
    _distances, neighbor_paths = nn.search_from_url(url, k)
    _figure, axes = plt.subplots(1, k + 1, figsize=(3*k, 10))

    query_ax = axes[0]
    display_from_url(url, query_ax)
    query_ax.axis('off')

    for result_ax, result_path in zip(axes[1:], neighbor_paths):
        display_from_path(result_path, result_ax)
        result_ax.axis('off')

    return neighbor_paths

def search_and_plot_from_path(path, k=5):
    """Plot the image file at *path* beside its *k* nearest indexed neighbours.

    The query image goes in the first subplot and each neighbour in one of
    the following subplots.

    Args:
        path: Location of the query image on disk.
        k: Number of neighbours to retrieve and show.

    Returns:
        The file paths of the neighbours found by the module-level ``nn``.
    """
    _, neighbor_paths = nn.search_from_path(path, k)
    _, axs = plt.subplots(1, k + 1, figsize=(3*k, 10))
    display_from_path(path, axs[0])
    # Hide ticks and frames, matching `search_and_plot_from_url`.
    axs[0].axis('off')
    for ax, neighbor_path in zip(axs[1:], neighbor_paths):
        display_from_path(neighbor_path, ax)
        ax.axis('off')
    return neighbor_paths

In [None]:
# Demo: query the index with an image fetched from a URL (default k=5) and
# plot it next to the neighbours found.
search_and_plot_from_url('https://img.letgo.com/images/b3/9a/01/67/b39a0167d370e3a220982d94c99bceb0.jpeg?impolicy=img_600')

In [None]:
# Demo: query with the last indexed image itself, asking for 10 neighbours.
search_and_plot_from_path(paths[-1], 10)