In [2]:
import tensorflow as tf
from tensorflow.keras.layers import GlobalMaxPooling2D
from tensorflow.keras.applications.resnet50 import ResNet50,preprocess_input
import numpy as np
from numpy.linalg import norm   
from tqdm import tqdm
import pickle



In [3]:
# Show the devices TensorFlow can see, via the stable (non-experimental)
# entry point of the same query.
tf.config.list_physical_devices()

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [4]:
# ResNet50 with ImageNet weights and no classification top, taking
# 224x224x3 inputs; frozen because it is only used for inference here.
model = ResNet50(
    include_top=False,
    weights='imagenet',
    input_shape=(224, 224, 3),
)
model.trainable = False

In [5]:
# Stack a global max-pooling layer on the backbone so each image yields one
# flat vector instead of a spatial feature map.
# NOTE: this rebinds `model`, so re-running this cell without re-running the
# cell that builds the backbone would wrap the already-wrapped model again.
model = tf.keras.Sequential(layers=[model, GlobalMaxPooling2D()])

In [6]:
# summary() prints the table itself and returns None, so wrapping it in
# print() only adds a stray "None" line after the table.
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
resnet50 (Functional)        (None, 7, 7, 2048)        23587712  
_________________________________________________________________
global_max_pooling2d (Global (None, 2048)              0         
Total params: 23,587,712
Trainable params: 0
Non-trainable params: 23,587,712
_________________________________________________________________
None


In [7]:
from tensorflow.keras.preprocessing import image


In [8]:
def extract_features(img_path,model):
    """Return the L2-normalised embedding of a single image.

    The image at ``img_path`` is loaded at 224x224, converted to an array,
    given a leading batch axis, run through ``preprocess_input`` and then
    through ``model.predict``. The prediction is flattened to 1-D and divided
    by its Euclidean norm.

    Args:
        img_path: Path of the image file to embed.
        model: Object exposing a Keras-style ``predict(batch, verbose=...)``.

    Returns:
        A 1-D array: the flattened prediction scaled to unit L2 norm. When
        the prediction has zero norm it is returned unscaled (all zeros)
        instead of being turned into NaNs by a division by zero.
    """
    img = image.load_img(img_path, target_size=(224, 224))
    img_array = image.img_to_array(img)
    # The model consumes batches, so add a batch axis of size 1.
    expanded_img_array = np.expand_dims(img_array, axis=0)
    preprocessed_img = preprocess_input(expanded_img_array)
    # verbose=0 keeps predict() silent; this function is called once per
    # image, and a per-call progress bar would flood the notebook output.
    result = model.predict(preprocessed_img, verbose=0).flatten()

    magnitude = norm(result)
    if magnitude == 0:
        # Nothing to scale; dividing would produce NaNs.
        return result

    return result / magnitude

In [9]:
import os

In [10]:
# Collect the path of every regular file directly inside 'images'.
# os.listdir() returns entries in arbitrary order, so sort them: position i
# in this list must refer to the same image on every run, because the
# embeddings computed from it are matched back to it by index.
filenames = []

for file in sorted(os.listdir('images')):
    path = os.path.join('images', file)
    # Skip sub-directories; only files can be loaded as images.
    if os.path.isfile(path):
        filenames.append(path)
    

In [11]:
# Sanity check: how many paths were collected, and the first six of them.
print(len(filenames), filenames[:6], sep='\n')

44441
['images\\10000.jpg', 'images\\10001.jpg', 'images\\10002.jpg', 'images\\10003.jpg', 'images\\10004.jpg', 'images\\10005.jpg']


In [None]:
# One embedding per path, in the same order as `filenames`. Appending inside
# the loop keeps the embeddings gathered so far if the run is interrupted.
feature_list = []

for img_path in tqdm(filenames):
    embedding = extract_features(img_path, model)
    feature_list.append(embedding)

 14%|█▎        | 6017/44441 [39:48<1:16:24,  8.38it/s]    

In [None]:
# Persist the embeddings and the matching file paths. The with-blocks
# guarantee each file is flushed and closed, even if pickling raises.
with open('embeddings.pkl', 'wb') as embeddings_file:
    pickle.dump(feature_list, embeddings_file)

with open('filenames', 'wb') as filenames_file:
    pickle.dump(filenames, filenames_file)

In [None]:
# Shape of the embedding collection when viewed as an array.
print(np.shape(feature_list))