In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

import tensorflow as tf

import keras.models as KM
import keras.applications as KA
import keras.layers as KL


Using TensorFlow backend.


In [2]:
# Paths of images folders
PATH_BG = "..\\data\\bg\\"
PATH_DOG1 = "..\\data\\dog1\\"

# Images parameters for network feeding
IM_H = 224
IM_W = 224
IM_C = 3

# Training parameters:
EPOCHS = 1
BATCH_SIZE = 32

# Embedding size
EMB_SIZE = 128

In [3]:
# Retrieve filenames
filenames_bg = []
for file in os.listdir(PATH_BG):
        if ".jpg" in file:
                filenames_bg += [file]

filenames_dog1 = []
for file in os.listdir(PATH_DOG1):
        if ".jpg" in file:
                filenames_dog1 += [file]

# Opens an image file, stores it into a tf.Tensor and reshapes it
def _parse_function(filename, label):
        image_string = tf.read_file(filename)
        image_decoded = tf.image.decode_jpeg(image_string, channels=3)
        image_resized = tf.image.resize_images(image_decoded, [IM_H, IM_W])
        return image_resized, label

filenames = np.append(
        [PATH_DOG1 + filenames_dog1[i] for i in range(len(filenames_dog1))],
        [PATH_BG + filenames_bg[i] for i in range(len(filenames_bg))],
        axis=0
        )
labels = np.append(np.ones(len(filenames_dog1)), np.arange(2,2+len(filenames_bg)))

# Filenames and labels place holder
filenames_placeholder = tf.placeholder(filenames.dtype, filenames.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

# Defining dataset
dataset = tf.data.Dataset.from_tensor_slices((filenames_placeholder, (filenames_placeholder, labels_placeholder)))
dataset = dataset.map(_parse_function)

# Batch the dataset for training
data_train = dataset.repeat(EPOCHS).batch(BATCH_SIZE)
iterator = data_train.make_initializable_iterator()
next_element = iterator.get_next()

# Define the dataset for loss computation
# embedding_placeholder = tf.placeholder(tf.float32, shape=(None, EMB_SIZE))
# data_loss = tf.data.Dataset.from_tensor_slices((filenames_placeholder, (labels_placeholder, embedding_placeholder)))

# Defining the global prediction dictionary
data_dict = np.concatenate(
        (np.expand_dims(filenames,1), np.expand_dims(labels,1)),
        axis = 1
        )

data_pred = {filename:(label, np.random.random_sample((EMB_SIZE,))) for filename, label in data_dict}


In [4]:
# Build the model using Keras pretrained model NASNetMobile,
# a light and efficient network
def NASNet_embedding(
        input_tensor,
        input_shape=(224,224,3),
        include_top=False,
        training=True
        ):

#         base_model = KA.NASNetMobile(
#                 input_tensor=input_tensor,
#                 input_shape=input_shape,
#                 include_top=False
#                 )
#         x = KL.GlobalAveragePooling2D()(base_model.output)
        x = KL.GlobalAveragePooling2D()(input_tensor)
        x = KL.Dense(1056, activation='relu')(x)
        if training:
                x = KL.Dropout(0.5)(x)
        x = KL.Dense(EMB_SIZE)(x)
        x = tf.keras.backend.l2_normalize(x)

        return x


In [5]:
def deviation_loss(dict_pred):
        sum_class_loss = 0
        classes_loss = 0

        class_pred = {}

        # Compute all center of mass
        for _, (label, pred) in dict_pred.items():
                if label in class_pred.keys():
                        class_pred[label][0] += pred
                        class_pred[label][1] += 1
                else:
                        class_pred[label] = [pred,1]
        for label in class_pred:
                class_pred[label][0] /= class_pred[label][1]
        
        # Compute all classes center of mass
        class_pred_values = np.array(list(class_pred.values()))[:,0]
        print(class_pred_values.shape)
        classes_center = np.repeat(np.expand_dims(np.sum(class_pred_values)/len(class_pred),axis=0), len(class_pred_values), axis=0)
        classes_loss -= np.sum(np.log(np.linalg.norm(class_pred_values - classes_center)))
        
        # Compute 
        for _, (label, pred) in dict_pred.items():
                sum_class_loss += np.linalg.norm(pred - class_pred[label])

        return classes_loss + sum_class_loss

In [43]:
val, (file, lab) = next_element

y_pred = NASNet_embedding(val)

sess = tf.Session()

init = tf.global_variables_initializer()
sess.run(iterator.initializer, feed_dict={filenames_placeholder:filenames, labels_placeholder:labels})
sess.run(init)


val_files, val_preds = sess.run([file, y_pred])
for i in range(len(val_files)):
    #print("before : ")
    file = val_files[i].decode('utf-8')
    #print(data_pred[file])
    data_pred[file] = (data_pred[file][0], val_preds[i])
    #print("after : ")
    #print(data_pred[file])

loss = deviation_loss(data_pred)

(66,)


ValueError: operands could not be broadcast together with shapes (66,) (66,128) 

In [17]:
val_files[0].decode('utf-8')

'..\\data\\dog1\\119228435530689461.jpg'

In [19]:
for i in range(len(val_files)):
    print("before : ")
    file = val_files[i].decode('utf-8')
    print(data_pred[file])
    data_pred[file] = (data_pred[file][0], val_preds[i])
    print("after : ")
    print(data_pred[file])

before : 
('1.0', array([0.25012738, 0.24885891, 0.96956038, 0.90026831, 0.27023706,
       0.77984417, 0.26399032, 0.12644127, 0.15092847, 0.6904359 ,
       0.29201641, 0.15648301, 0.53832075, 0.85924683, 0.85835124,
       0.82535647, 0.70920617, 0.26731372, 0.41331575, 0.47661387,
       0.8551501 , 0.51977242, 0.91808642, 0.25930064, 0.84044537,
       0.39411997, 0.73726828, 0.3464977 , 0.25610745, 0.62179136,
       0.35924839, 0.95053404, 0.72320419, 0.94814901, 0.05814724,
       0.23112463, 0.45438016, 0.54872613, 0.69770205, 0.80200806,
       0.73515688, 0.25327283, 0.05395533, 0.02021352, 0.83407735,
       0.0957621 , 0.14531702, 0.13140698, 0.42633016, 0.69484677,
       0.99934418, 0.75904054, 0.31104001, 0.22451297, 0.61757607,
       0.06336793, 0.04285033, 0.63337917, 0.73536061, 0.78395846,
       0.62235765, 0.55021516, 0.93927818, 0.81305537, 0.45715411,
       0.6328232 , 0.41297764, 0.81977977, 0.31591519, 0.53549885,
       0.0366196 , 0.35520654, 0.63984351, 0

       0.51684629, 0.40195069, 0.86038035]))
after : 
('1.0', array([-0.00950272,  0.00609177,  0.01646642, -0.01674177,  0.01632488,
       -0.00695133, -0.00913372,  0.01942851, -0.00175399,  0.0273061 ,
       -0.02199635,  0.00942165,  0.01734877,  0.0114172 , -0.00273204,
       -0.01744892, -0.00182343,  0.00232514, -0.01042074, -0.00729368,
       -0.01225919, -0.01505204, -0.00358568, -0.01024586, -0.00139441,
        0.01278868,  0.02622678, -0.03628683,  0.0140245 ,  0.0017754 ,
       -0.01337241,  0.0119019 ,  0.03071779, -0.00576574, -0.00380787,
       -0.00204205,  0.01630234, -0.05337519,  0.00750745, -0.01649908,
        0.00070701, -0.01314553, -0.00223038,  0.02371406,  0.01181231,
        0.00166175,  0.03256505,  0.01524732, -0.0127828 , -0.01068392,
        0.00654865, -0.015254  , -0.00912557, -0.00672644,  0.00994674,
       -0.00174074, -0.02761319, -0.01855658,  0.00321746,  0.02406985,
       -0.01967716,  0.01544928,  0.00021859,  0.02460167,  0.00952335,
  

        0.00512017,  0.02330386,  0.00381929], dtype=float32))
before : 
('1.0', array([0.7198685 , 0.67274607, 0.39376606, 0.17088296, 0.3258542 ,
       0.2046301 , 0.37035631, 0.08360406, 0.10305615, 0.57940685,
       0.5279583 , 0.65758817, 0.30892024, 0.36475435, 0.53072071,
       0.78087683, 0.81562131, 0.87128455, 0.96391974, 0.91633364,
       0.54653555, 0.26788093, 0.72613839, 0.22004702, 0.77661439,
       0.6176667 , 0.39805137, 0.01519703, 0.82740938, 0.41560716,
       0.32412885, 0.03014056, 0.88750422, 0.68332857, 0.42810367,
       0.05667157, 0.34653202, 0.06554401, 0.95764684, 0.04572643,
       0.52070589, 0.80508881, 0.01510325, 0.67662992, 0.32284808,
       0.35824097, 0.89316347, 0.4746951 , 0.03130319, 0.25275483,
       0.29867674, 0.56954553, 0.39304838, 0.65578597, 0.26703929,
       0.62965794, 0.90759197, 0.23024821, 0.09713386, 0.15965621,
       0.80080469, 0.29080154, 0.19002638, 0.84824453, 0.63049832,
       0.85499408, 0.41308616, 0.89324551, 0.003

       0.34445433, 0.45520093, 0.87406353]))
after : 
('1.0', array([-0.00914637,  0.00575839,  0.01567203, -0.0159306 ,  0.01559186,
       -0.00653622, -0.00876984,  0.01860082, -0.0016741 ,  0.0258464 ,
       -0.02088054,  0.00892373,  0.01640606,  0.01093807, -0.00261302,
       -0.01664357, -0.00165064,  0.00221247, -0.00978579, -0.00692487,
       -0.01166977, -0.01427267, -0.00346479, -0.00962426, -0.00147121,
        0.01222958,  0.02504184, -0.0344038 ,  0.01329848,  0.00186451,
       -0.0127817 ,  0.01124926,  0.02912594, -0.00546647, -0.0036491 ,
       -0.00186186,  0.01556623, -0.05091332,  0.00710624, -0.01567571,
        0.00067073, -0.01257027, -0.00193689,  0.02255868,  0.01114367,
        0.00163986,  0.03090657,  0.01450926, -0.01208931, -0.01015305,
        0.00622557, -0.01463626, -0.00857351, -0.00653607,  0.00948445,
       -0.00168146, -0.02634569, -0.01768508,  0.00283919,  0.02280748,
       -0.01889466,  0.01458941,  0.00035236,  0.0234103 ,  0.0091496 ,
  

        0.00502102,  0.02105835,  0.00382385], dtype=float32))
before : 
('1.0', array([0.25459542, 0.05003346, 0.30705807, 0.8690899 , 0.96085117,
       0.05004076, 0.73266468, 0.19795184, 0.79312382, 0.77226971,
       0.98266347, 0.55696747, 0.99746366, 0.48415895, 0.94237366,
       0.7578336 , 0.6638826 , 0.32817946, 0.04908887, 0.38108227,
       0.33072929, 0.45632905, 0.42932598, 0.14616045, 0.73020957,
       0.06165956, 0.54710999, 0.56269915, 0.43997115, 0.68505602,
       0.06039864, 0.79896154, 0.1002646 , 0.73048124, 0.82848071,
       0.69237064, 0.94798373, 0.27665437, 0.83655159, 0.77203391,
       0.3010797 , 0.74105657, 0.2102486 , 0.40935968, 0.77096679,
       0.96954607, 0.75763613, 0.6898784 , 0.29888283, 0.38573261,
       0.67106865, 0.71934478, 0.68213299, 0.56211247, 0.81514007,
       0.83597602, 0.69229057, 0.07320846, 0.38369034, 0.74532804,
       0.03099968, 0.98585999, 0.73194818, 0.88151477, 0.28084537,
       0.10704505, 0.4846904 , 0.86738746, 0.469

      dtype=float32))
before : 
('1.0', array([0.34733329, 0.42118938, 0.97345546, 0.64301851, 0.05962162,
       0.98712547, 0.99870789, 0.34653442, 0.04167232, 0.88773954,
       0.09522177, 0.11380115, 0.01341888, 0.28345928, 0.34421826,
       0.36356954, 0.42824632, 0.28361202, 0.57586114, 0.81466786,
       0.11370369, 0.85871106, 0.48551664, 0.06927496, 0.0799738 ,
       0.54502795, 0.52246743, 0.42883319, 0.40707907, 0.0243348 ,
       0.90137664, 0.5749046 , 0.3994565 , 0.75654913, 0.47968676,
       0.12275885, 0.61324393, 0.58204562, 0.3459481 , 0.26132332,
       0.58075492, 0.3010486 , 0.03196723, 0.18487355, 0.38317759,
       0.38193465, 0.83889775, 0.19482864, 0.24086636, 0.80399258,
       0.84870677, 0.00352486, 0.46822415, 0.37570086, 0.21921747,
       0.7135175 , 0.62452322, 0.28250252, 0.69590278, 0.13844164,
       0.5851351 , 0.2723663 , 0.92448588, 0.43509254, 0.48335192,
       0.12099041, 0.81227894, 0.12839078, 0.40261286, 0.20361123,
       0.23079212, 0.7

       0.12608691, 0.7779555 , 0.89127021]))
after : 
('1.0', array([-1.07180728e-02,  6.11219741e-03,  1.83672607e-02, -1.81936640e-02,
        1.81236640e-02, -7.18890782e-03, -1.02144564e-02,  2.20642220e-02,
       -1.67274079e-03,  2.94935945e-02, -2.42251102e-02,  1.01923477e-02,
        1.88483894e-02,  1.33477673e-02, -2.87920516e-03, -1.92957539e-02,
       -1.58257654e-03,  2.87638232e-03, -1.10406475e-02, -7.81176705e-03,
       -1.38944602e-02, -1.61680058e-02, -4.25991509e-03, -1.07119270e-02,
       -2.23330990e-03,  1.41556179e-02,  2.89944224e-02, -3.93311195e-02,
        1.48047348e-02,  2.86020059e-03, -1.48208411e-02,  1.27563141e-02,
        3.34714353e-02, -6.25505764e-03, -4.24763933e-03, -1.94460619e-03,
        1.82755515e-02, -5.90707771e-02,  8.04895721e-03, -1.79522056e-02,
        8.54687125e-04, -1.46898143e-02, -1.70022156e-03,  2.60622352e-02,
        1.25294402e-02,  2.34844093e-03,  3.54142301e-02,  1.65665429e-02,
       -1.35935321e-02, -1.15313781e-0

      dtype=float32))
before : 
('1.0', array([0.46798432, 0.35938124, 0.95669683, 0.57494064, 0.76961638,
       0.47851977, 0.06541619, 0.76359762, 0.50530591, 0.99551732,
       0.5316194 , 0.4835379 , 0.28646096, 0.28047405, 0.85452217,
       0.81980684, 0.49701276, 0.34038699, 0.25281507, 0.8349989 ,
       0.66645543, 0.94612646, 0.69467864, 0.87753856, 0.3608665 ,
       0.23967133, 0.34955525, 0.16724892, 0.38219667, 0.93133687,
       0.39863124, 0.73622107, 0.8477717 , 0.47590913, 0.34502612,
       0.61586412, 0.81290771, 0.56826454, 0.46512162, 0.42259549,
       0.98539706, 0.38736754, 0.65064121, 0.2282398 , 0.7538174 ,
       0.63677174, 0.9208856 , 0.1355221 , 0.29683201, 0.48674053,
       0.96643864, 0.12202254, 0.92557872, 0.30297404, 0.77429813,
       0.17121633, 0.34213312, 0.103404  , 0.55016056, 0.42901235,
       0.0900545 , 0.12916497, 0.27403938, 0.41543108, 0.64890643,
       0.47761564, 0.07796036, 0.48538898, 0.22212654, 0.37360371,
       0.92094176, 0.9

       2.65409319e-01, 7.80364589e-01, 6.71289658e-01, 7.06970645e-01]))
after : 
('2.0', array([-0.0125055 ,  0.00607879,  0.02125219, -0.02036234,  0.02075234,
       -0.00749137, -0.01189524,  0.02609024, -0.00155047,  0.03278763,
       -0.02762022,  0.01140945,  0.02091512,  0.01626671, -0.00297822,
       -0.02195091, -0.00117697,  0.00372192, -0.01189269, -0.00857123,
       -0.01631916, -0.01783397, -0.00528243, -0.01147532, -0.00361722,
        0.0162071 ,  0.03304283, -0.04371183,  0.01596826,  0.00453897,
       -0.01703609,  0.0140115 ,  0.03749125, -0.00688559, -0.00490468,
       -0.00174211,  0.02121546, -0.06749444,  0.00887717, -0.02009082,
        0.00109673, -0.01685604, -0.00077345,  0.02949115,  0.01351183,
        0.00353538,  0.03960128,  0.01855611, -0.01474593, -0.0128761 ,
        0.00825902, -0.02056512, -0.00938668, -0.00976934,  0.01293157,
       -0.00215885, -0.03501681, -0.02366026,  0.00167631,  0.02901613,
       -0.02538836,  0.01822593,  0.0018268 , 

      dtype=float32))
before : 
('5.0', array([0.21357187, 0.3939227 , 0.8631511 , 0.69411524, 0.81526718,
       0.77720767, 0.69363026, 0.88159414, 0.67060582, 0.34161964,
       0.891202  , 0.06932831, 0.0260438 , 0.14968564, 0.66894181,
       0.92050269, 0.85574255, 0.97613488, 0.92964762, 0.29815735,
       0.45124094, 0.32345281, 0.41716955, 0.56436232, 0.74244824,
       0.19728229, 0.1769458 , 0.78792288, 0.74113919, 0.27678176,
       0.23569259, 0.44230224, 0.12912975, 0.78664283, 0.34525046,
       0.6179223 , 0.07580547, 0.57517177, 0.18793325, 0.21645012,
       0.5693583 , 0.71460381, 0.39883443, 0.1235306 , 0.83562269,
       0.84232699, 0.87910727, 0.9077286 , 0.43657914, 0.81338226,
       0.36122996, 0.33465474, 0.38015553, 0.51010685, 0.84964175,
       0.58440289, 0.90244096, 0.29672549, 0.13649112, 0.4427511 ,
       0.12261992, 0.16823719, 0.60458268, 0.58201896, 0.37806314,
       0.17777477, 0.18159735, 0.44250873, 0.68882215, 0.52943584,
       0.09927454, 0.1

       0.58034804, 0.54708547, 0.20441861]))
after : 
('7.0', array([-0.00929055,  0.00431673,  0.01624063, -0.01521212,  0.01559406,
       -0.00560992, -0.00882038,  0.01983684, -0.00094351,  0.02490051,
       -0.02105162,  0.00866752,  0.01603832,  0.01263308, -0.00208946,
       -0.01650419, -0.0008782 ,  0.00305754, -0.00916059, -0.00638425,
       -0.01261812, -0.01339118, -0.00401354, -0.00869874, -0.00275775,
        0.01205995,  0.0247036 , -0.03301015,  0.01177008,  0.00351504,
       -0.01275339,  0.01056843,  0.02844503, -0.00522887, -0.00366275,
       -0.00139189,  0.01603387, -0.05077695,  0.00665348, -0.01514576,
        0.00092232, -0.01265242, -0.00064923,  0.02225788,  0.0101895 ,
        0.00289658,  0.02989403,  0.01384124, -0.0110128 , -0.0095522 ,
        0.00625641, -0.01552833, -0.00688604, -0.00731014,  0.00982527,
       -0.00151722, -0.02633349, -0.01786296,  0.0014343 ,  0.02193334,
       -0.01887827,  0.01381877,  0.00134841,  0.02279066,  0.01027733,
  

        0.00845034,  0.02174419,  0.00763362], dtype=float32))
before : 
('10.0', array([0.5911337 , 0.57426541, 0.85482709, 0.23590835, 0.8695435 ,
       0.71198846, 0.01350009, 0.10088431, 0.1529003 , 0.55393716,
       0.84904831, 0.42567254, 0.78614796, 0.2349456 , 0.41689332,
       0.37178844, 0.12214055, 0.63775942, 0.09554921, 0.84005994,
       0.46009365, 0.35843917, 0.66499501, 0.17938476, 0.11787782,
       0.67914288, 0.59912131, 0.51052559, 0.16840131, 0.99645503,
       0.90572958, 0.0830444 , 0.14566295, 0.11125399, 0.20838553,
       0.75108414, 0.23807884, 0.44730085, 0.19345844, 0.26315903,
       0.11821898, 0.05831473, 0.92002101, 0.43640059, 0.10187064,
       0.90635298, 0.04471829, 0.38493481, 0.29539333, 0.19503657,
       0.66929071, 0.41731752, 0.03001857, 0.29110465, 0.7562142 ,
       0.53152169, 0.97421309, 0.12716238, 0.8786555 , 0.77257559,
       0.9044795 , 0.9356216 , 0.87026373, 0.40131475, 0.13600186,
       0.6603935 , 0.49138303, 0.98947421, 0.71

       0.08254918, 0.2199163 , 0.89692041]))
after : 
('12.0', array([-0.00759   ,  0.00204735,  0.01756536, -0.01345772,  0.01452428,
       -0.00538955, -0.00746297,  0.0202185 ,  0.00082424,  0.02565296,
       -0.02219618,  0.00917377,  0.01779434,  0.01466288, -0.00093212,
       -0.01564407, -0.00094516,  0.00507021, -0.01073451, -0.00538207,
       -0.01455421, -0.01302181, -0.00373802, -0.00896272, -0.00270772,
        0.01034436,  0.02170912, -0.0322342 ,  0.00957467,  0.00387719,
       -0.01144258,  0.0104118 ,  0.02902104, -0.00539042, -0.00335153,
       -0.001998  ,  0.0158519 , -0.04821377,  0.00626874, -0.01461715,
        0.00171398, -0.01191165, -0.00114795,  0.02176754,  0.01029605,
        0.00437415,  0.02917097,  0.01209404, -0.01004123, -0.00810235,
        0.00656306, -0.01537377, -0.00533521, -0.00663826,  0.00983246,
       -0.00029012, -0.02513484, -0.01738575,  0.00302768,  0.02198418,
       -0.01573218,  0.01405392,  0.00089702,  0.02087701,  0.01122938,
 

        0.00038693,  0.01818866, -0.00099669], dtype=float32))


In [27]:
d={1:(2,3), 3:(5,8)}
for a, (b,c) in d.items():
    print(a)
    print(b)
    print(c)
print(list(d.values()))

1
2
3
3
5
8
[(2, 3), (5, 8)]


In [39]:
x =  [[1,2]]
np.repeat(x, 2, axis=0)

array([[1, 2],
       [1, 2]])