In [1]:
from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [2]:
from glob import glob
import cv2
import numpy as np
import pandas as pd
from random import shuffle
from tqdm import tqdm
import pickle

from keras.models import load_model
from keras.models import Model

def snip_model(m, layer_name = "global_average_pooling2d_1"):
  return Model(inputs=m.inputs, outputs=m.get_layer(layer_name).output)

def resize_image(im_array, resolution = (32,32)):
    return cv2.resize(im_array, resolution)

def get_files_list(folder):
  return sorted(glob(folder))

def cosine_similarity(a,b):
  return np.dot(a, b)/(np.linalg.norm(a)*np.linalg.norm(b))

def get_class(file_name):
  for k, v in class_id.items():
    if v in file_name:
      return k

def get_bw(clr_img):
    return np.expand_dims(cv2.cvtColor(clr_img, cv2.COLOR_BGR2GRAY), axis=-1)

def get_file_features(model, fl, read = False, gray = False, resize = False):
  if read: fl = cv2.imread(fl)
  if gray: fl = get_bw(fl)
  if resize: fl = resize_image(fl)
  fl = fl/255
  fl = np.expand_dims(fl, axis = 0)
  pred = model.predict(fl)
  return np.squeeze(pred)

def pred_img(model, img, gray = False, resize = False):
  if gray: img = get_bw(img)
  if resize: img = resize_image(img)
  img = np.expand_dims(img, axis = 0)
  img = model.predict(img)
  return np.squeeze(img)
  
def load_cifar():
  from keras.datasets import cifar10
  (x_train, y_train), (x_test, y_test) = cifar10.load_data()
  x_train = x_train.astype('float32')
  x_test = x_test.astype('float32')
  x_train /= 255
  x_test /= 255
  return x_train, y_train, x_test, y_test

Using TensorFlow backend.


In [3]:
classed_id = {'airplane' : 0, 'automobile' : 1, 'bird' : 2, 'cat' : 3, 'deer' : 4, 'dog' : 5, 'frog' : 6, 'horse' : 7, 'ship' : 8, 'truck' : 9}

clr_model = load_model('/content/drive/My Drive/CIFAR_clr.h5')
gs_model = load_model('/content/drive/My Drive/CIFAR_gs.h5')














Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where




In [0]:
clr_snip = snip_model(clr_model, layer_name = "global_average_pooling2d_1")
gs_snip = snip_model(gs_model, "global_average_pooling2d_2")

In [0]:
def get_bw_data(x_train):
  return np.array([get_bw(i) for i in x_train])

x_train, y_train, x_test, y_test = load_cifar()
x_train_g, x_test_g = get_bw_data(x_train), get_bw_data(x_test)


In [0]:
def test_models():
  Xt = x_test
  yt = y_test
  cm = clr_model
  cms = clr_snip
  gm = gs_model
  gms = gs_snip
  
  test_set = [(x,y) for x, y in zip(Xt,yt)]

  results_dict = dict()
  for idx, (image_array, label) in tqdm(enumerate(test_set)):
    results_dict[idx] = {
        "img_array": image_array,
        "actual_label": np.squeeze(label),

        "clr_label": np.argmax(pred_img(cm, image_array, gray=False)),
        "clr_features": pred_img(cms, image_array, gray=False),

        "gs_label": np.argmax(pred_img(gm, image_array, gray=True)),
        "gs_features": pred_img(gms, image_array, gray=True)
    }
  return results_dict

In [13]:
results = test_models()
with open("clr_gs_results_full.pkl", "wb") as f:
  pickle.dump(results, f)

10000it [06:45, 24.67it/s]


In [0]:
random_indexs = np.random.random_integers(0,10000, 100)

  """Entry point for launching an IPython kernel.


In [0]:
test_set = dict()
for idx, i in enumerate(random_indexs):
  test_set[idx] = results[i]

In [0]:
with open("clr_gs_test.pkl", "wb") as f:
  pickle.dump(test_set, f)

In [46]:
np.argmax(clr_model.predict(np.expand_dims(x_test[1], axis=0)))

8

In [44]:
y_test[1]

array([8])

In [47]:
test_set

{0: {'actual_label': array(9),
  'clr_features': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),
  'clr_label': 3,
  'gs_features': array([0.        , 0.74737567, 0.        , ..., 0.        , 0.        ,
         0.        ], dtype=float32),
  'gs_label': 0,
  'img_array': array([[[0.61960787, 0.654902  , 0.67058825],
          [0.91764706, 0.9529412 , 0.96862745],
          [0.9372549 , 0.972549  , 0.9843137 ],
          ...,
          [0.7647059 , 0.85882354, 0.8745098 ],
          [0.65882355, 0.75686276, 0.77254903],
          [0.6901961 , 0.78039217, 0.8039216 ]],
  
         [[0.60784316, 0.6431373 , 0.6627451 ],
          [0.89411765, 0.92941177, 0.9490196 ],
          [0.8901961 , 0.9254902 , 0.94509804],
          ...,
          [0.827451  , 0.9019608 , 0.92941177],
          [0.73333335, 0.81960785, 0.84705883],
          [0.5647059 , 0.64705884, 0.6862745 ]],
  
         [[0.627451  , 0.6627451 , 0.6784314 ],
          [0.9137255 , 0.94509804, 0.9647059 ],
          [0.