In [1]:
from tensorflow.keras.applications.efficientnet_v2 import EfficientNetV2S
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.efficientnet_v2 import preprocess_input
from tensorflow.keras.models import Model
import numpy as np
import os
from tqdm import tqdm

In [2]:
# Root folder of the input images; dataset_extraction() treats every entry
# directly under it as one class folder.
DIR = "../../data/stage_clf_enh_img_manual2/"
# Root folder for the extracted embeddings (one sub-folder per class).
OUT = "../../data/stage_clf_enh_img_manual2_embeddings/"

# makedirs(exist_ok=True) replaces the exists()/mkdir() pair: it is safe to
# re-run, has no check-then-create race, and also creates missing parents.
os.makedirs(OUT, exist_ok=True)

In [3]:
# Headless EfficientNetV2-S with ImageNet weights: include_top=False drops the
# classification head and pooling='avg' applies global average pooling to the
# final feature map, so the model's output is what gets saved as the embedding.
model = EfficientNetV2S(weights='imagenet', include_top=False, pooling='avg')
# Earlier alternative, kept for reference only (`base_model` is not defined in
# the visible cells, so these lines would not run as written):
# layer_name = "avg_pool"
# model = Model(inputs=base_model.input, outputs=base_model.get_layer(layer_name).output)
model.summary()

Model: "efficientnetv2-s"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, None, None,  0           []                               
                                 3)]                                                              
                                                                                                  
 rescaling (Rescaling)          (None, None, None,   0           ['input_1[0][0]']                
                                3)                                                                
                                                                                                  
 stem_conv (Conv2D)             (None, None, None,   648         ['rescaling[0][0]']              
                                24)                                                

In [4]:
def preprocess_image(path, target_size=(384, 384)):
    """Load one image file and turn it into a single-item batch for the model.

    Args:
        path: Path of the image file to load.
        target_size: ``(height, width)`` the image is resized to while being
            loaded. Defaults to ``(384, 384)``, the size this notebook used
            before the value became configurable.

    Returns:
        The image as an array with a leading batch axis of length 1, after
        being passed through ``preprocess_input``.

    Raises:
        OSError: Propagated from the image loader when the file is missing
            or cannot be decoded as an image.
    """
    img = image.load_img(path, target_size=target_size)
    x = image.img_to_array(img)
    # The model consumes batches, so add a batch axis in front.
    x = np.expand_dims(x, axis=0)
    # NOTE(review): the model summary shows a built-in Rescaling layer, so
    # preprocess_input is presumably a pass-through for EfficientNetV2 -
    # kept for API symmetry; confirm against the installed Keras version.
    return preprocess_input(x)

In [5]:
def folder_extraction(dir, out):
    """Extract and save one embedding per image file found in a folder.

    For every entry in ``dir`` the image is preprocessed, run through the
    module-level ``model`` and the result is stored as ``<stem>.npy`` in
    ``out``. Entries whose ``.npy`` file already exists are skipped, which
    makes the function cheap to re-run after an interruption.

    Args:
        dir: Folder containing the input images.
        out: Existing folder that receives the ``.npy`` embedding files.

    Notes:
        * The output stem is the file name without its final extension
          (``os.path.splitext``). The previous ``split('.png')[0]`` cut the
          name at the first ``.png`` occurring anywhere in it and left other
          extensions in place (``a.jpg`` -> ``a.jpg.npy``).
        * Files that cannot be read or written (``OSError``) are still
          skipped so one bad file does not abort the run, but each skip is
          now reported instead of being swallowed silently.
    """
    for file_name in tqdm(os.listdir(dir)):
        img_id = os.path.splitext(file_name)[0]
        img_path = os.path.join(dir, file_name)
        emb_path = os.path.join(out, img_id + '.npy')

        # Already extracted in an earlier run - nothing to do.
        if os.path.exists(emb_path):
            continue

        try:
            x = preprocess_image(img_path)
            features = model.predict(x, verbose=0)
            np.save(emb_path, features)
        except OSError as err:
            # tqdm.write prints without corrupting the progress bar.
            tqdm.write(f"Skipping {img_path}: {err}")

In [6]:
def dataset_extraction():
    """Extract embeddings for every class folder under ``DIR``.

    Each sub-folder of ``DIR`` is treated as one class; a folder with the
    same name is created under ``OUT`` and filled by ``folder_extraction``.
    Entries of ``DIR`` that are not directories (stray files such as
    ``.DS_Store``) are ignored; previously they crashed the run with
    ``NotADirectoryError`` when their contents were listed.
    """
    # sorted() only fixes the processing order; the files produced are the same.
    for cls in sorted(os.listdir(DIR)):
        cls_dir = os.path.join(DIR, cls)
        if not os.path.isdir(cls_dir):
            continue

        cls_out = os.path.join(OUT, cls)
        # Idempotent and race-free; also creates OUT itself if it is missing.
        os.makedirs(cls_out, exist_ok=True)
        folder_extraction(cls_dir, cls_out)

In [7]:
# Run the extraction over every entry under DIR. Images whose .npy embedding
# already exists are skipped, so re-running this cell only fills in what is missing.
dataset_extraction()

100%|██████████| 302/302 [00:00<00:00, 2359.17it/s]
  0%|          | 0/300 [00:00<?, ?it/s]

100%|██████████| 300/300 [00:00<00:00, 3370.58it/s]
