<a href="https://colab.research.google.com/github/aviresler/antique-gen/blob/master/antique_pedia.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Install packages - run once




In [None]:
!pip install keras_efficientnets
!pip install keras_applications

Download the image set and the trained model from Google Drive (run once)

In [None]:
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)
downloaded = drive.CreateFile({'id':"1oFKI98Xy6Apvu82HbfIbhRiWS4kVEwBO"})
downloaded.GetContentFile('images_600.zip')
downloaded = drive.CreateFile({'id':"1-5r_8H5LtSPvcTroXccSGLpulrHdT0RT"})
downloaded.GetContentFile('models.zip')
!unzip images_600.zip
!unzip models.zip
!rm images_600.zip
!rm models.zip

Prediction - can run multiple times.

- You will be asked to upload an image from your computer at the bottom of this cell.

- The 5 most similar training images, each titled with its label, will be shown at the bottom of the cell.

In [None]:
import csv
import functools
import os

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
from google.colab import files
from keras.layers import Dense, Activation, GlobalAveragePooling2D, Input
from keras.models import Model
from keras.preprocessing import image
from keras_efficientnets import EfficientNetB3
from sklearn.metrics.pairwise import cosine_similarity
%matplotlib inline

# Model artifacts, read from the working directory (presumably extracted
# from models.zip by the download cell above).
_RUN_PREFIX = 'efficientNetB3_softmax_f_10'
MODEL_PATH = _RUN_PREFIX + '-10-0.03_53.99.hdf5'

# Precomputed training-set embeddings (CSV) and their class ids (TSV).
_TRAIN_SET_STEM = _RUN_PREFIX + '__train9_53.99'
TRAINING_EMBEDDINGS = _TRAIN_SET_STEM + '.csv'
TRAINING_LABELS = _TRAIN_SET_STEM + '.tsv'

# Class id -> name table.
CLASSES_CSV_FILE = 'classes_top200.csv'


def build_efficientNet(num_classes=200, input_size=300):
    """Build the EfficientNet-B3 model with an embedding and a softmax output.

    The model takes one (input_size, input_size, 3) image input named
    'main_input'. It returns two outputs:
      * 'embeddings': the backbone features after global average pooling;
      * 'out': a softmax over num_classes classes computed from those features.

    Args:
        num_classes: width of the softmax head. The default (200) reproduces
            the previously hard-coded architecture; any weights loaded into the
            model afterwards must have been trained with the same value.
        input_size: side length in pixels of the square RGB input (default 300).

    Returns:
        A keras Model with outputs [embeddings, out].
    """
    input_shape = (input_size, input_size, 3)
    base_model = EfficientNetB3(input_shape, include_top=False)
    inp = Input(shape=input_shape, name='main_input')
    features = base_model(inp)
    embeddings = GlobalAveragePooling2D(name='embeddings')(features)
    logits = Dense(int(num_classes))(embeddings)
    out = Activation("softmax", name='out')(logits)

    return Model(inputs=inp, outputs=[embeddings, out])

def preprocess_input(x):
    """Rescale pixel intensities from [0, 255] to [-1, 1].

    Every step is an augmented assignment. A float ndarray argument is
    therefore rescaled in place and the same object is returned. An integer
    array cannot be divided in place and raises.
    """
    pixels = x
    pixels /= 255.  # [0, 255]    -> [0, 1]
    pixels -= 0.5   # [0, 1]      -> [-0.5, 0.5]
    pixels *= 2.    # [-0.5, 0.5] -> [-1, 1]
    return pixels

def test_model_on_query_imgaes(train_embeddings, train_labels, query_embeddings, classes_csv_file,
                               class_mode='site_period', n_neighbours=5,
                               train_files_csv='train_file_names.csv',
                               train_images_dir='images_600/train/'):
    """Return the training images most similar to the first query embedding.

    Args:
        train_embeddings: 2-D array with one embedding row per training image.
        train_labels: class id of each training image, in the same order as
            train_embeddings. Float ids such as 3.0 are accepted.
        query_embeddings: 2-D array of query embeddings. Only row 0 is used.
        classes_csv_file: CSV with one header row, mapping class ids to names.
        class_mode: which (id, name) column pair of classes_csv_file to use:
            'site_period', 'period' or 'site'.
        n_neighbours: maximum number of neighbours to return.
        train_files_csv: CSV whose first column holds the training image file
            names, in the same order as train_embeddings.
        train_images_dir: directory that those file names are relative to.

    Returns:
        dict mapping '<class name>_<rank>' to the neighbour's image path.
        Entries are inserted from most to least similar (rank 0 = most
        similar). The rank suffix keeps keys unique when neighbours share a
        class.

    Raises:
        ValueError: if class_mode is not one of the supported modes.
    """
    # (id column, name column) in classes_csv_file for each class_mode.
    columns_by_mode = {
        'site_period': (0, 1),
        'period': (5, 3),
        'site': (6, 4),
    }
    if class_mode not in columns_by_mode:
        raise ValueError(
            f"class_mode must be one of {sorted(columns_by_mode)}, got {class_mode!r}")
    id_col, name_col = columns_by_mode[class_mode]

    similarity = cosine_similarity(query_embeddings, train_embeddings, dense_output=True)
    # argsort is ascending; flip so column 0 is the most similar training image.
    ranked = np.flip(np.argsort(similarity, axis=1), axis=1)
    # Slicing (rather than indexing 0..n-1) copes with fewer training images
    # than requested neighbours.
    neighbours_ind = ranked[0, :n_neighbours]
    neighbours_cls = train_labels[neighbours_ind]

    with open(classes_csv_file, 'r', newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        class_names = {int(row[id_col]): row[name_col] for row in reader}

    with open(train_files_csv, 'r', newline='') as f:
        train_files = [row[0] for row in csv.reader(f)]

    neighbours = {}
    for rank, (train_idx, cls) in enumerate(zip(neighbours_ind, neighbours_cls)):
        # Title each neighbour with its own class, not the top match's class.
        title = class_names[int(cls)] + '_' + str(rank)
        neighbours[title] = os.path.join(train_images_dir, train_files[train_idx])

    return neighbours

@functools.lru_cache(maxsize=4)
def _load_training_set(embeddings_path, labels_path):
    """Load and memoise the training-set embeddings and their class labels.

    model_predict runs once per uploaded image, so caching avoids re-parsing
    the same text files every time. This assumes the files do not change
    while the notebook runs. Re-running this cell redefines the function and
    therefore empties the cache.
    """
    train_embeddings = np.genfromtxt(embeddings_path, delimiter=',')
    train_labels = np.genfromtxt(labels_path, delimiter='\t')
    return train_embeddings, train_labels


def model_predict(img_path, model):
    """Find the training images most similar to the image at img_path.

    Args:
        img_path: path of the query image file.
        model: the two-output model from build_efficientNet, with weights loaded.

    Returns:
        The dict built by test_model_on_query_imgaes (class_mode='site_period'),
        mapping '<class name>_<rank>' to a training-image path.
    """
    img = image.load_img(img_path, target_size=(300, 300))

    # Preprocess: float array, add a leading batch axis, rescale to [-1, 1].
    x = np.expand_dims(image.img_to_array(img), axis=0)
    x = preprocess_input(x)

    train_embeddings, train_labels = _load_training_set(TRAINING_EMBEDDINGS, TRAINING_LABELS)

    # The model outputs [embeddings, softmax]; only the embeddings are used
    # for the nearest-neighbour search.
    query_embeddings = model.predict(x)[0]

    return test_model_on_query_imgaes(train_embeddings, train_labels, query_embeddings,
                                      CLASSES_CSV_FILE, class_mode='site_period')

model = build_efficientNet()
model.load_weights(MODEL_PATH)


def _show_image(img, title):
    """Show one image in its own 10x10-inch figure, titled and without axes."""
    plt.figure(figsize=(10, 10))
    plt.imshow(img)
    plt.title(title, fontsize=24)
    plt.axis('off')


data_to_load = files.upload()
for key in data_to_load:  # keys are the uploaded file names
    try:
        summ = model_predict(key, model)
        _show_image(mpimg.imread(key), 'input image')
        for label, path in summ.items():
            _show_image(mpimg.imread(path), label)
    finally:
        # Delete the uploaded copy even if prediction or plotting fails, so a
        # failed run does not leave the file behind.
        os.remove(key)

