In [1]:
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications.efficientnet import preprocess_input
from tensorflow.keras.preprocessing import image_dataset_from_directory

On the keras.io/api/applications, there were many models available and EfficientNetB0 seemed to be good between performance and size.
It has not a lot of parameters and is fast.

In [4]:
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 32
EPOCHS_HEAD = 5 #At first I put more but its was too long
EPOCHS_FINE = 5

DATASET_IMAGE_POKEMON = "../data/pokemon-dataset-1000"
ENREGSITREMENT_MODELE = "../models/finetuned_efficientnetb0_pour_pokemon.h5"


In [5]:
train_dataset = image_dataset_from_directory(
    DATASET_IMAGE_POKEMON,
    image_size = IMAGE_SIZE,
    batch_size = BATCH_SIZE,
    label_mode = 'categorical',
)

class_names = train_dataset.class_names
num_classes = len(class_names)

print(class_names)
print(num_classes)




Found 20921 files belonging to 1000 classes.
['abomasnow', 'abra', 'absol', 'accelgor', 'aegislash-shield', 'aerodactyl', 'aggron', 'aipom', 'alakazam', 'alcremie', 'alomomola', 'altaria', 'amaura', 'ambipom', 'amoonguss', 'ampharos', 'annihilape', 'anorith', 'appletun', 'applin', 'araquanid', 'arbok', 'arboliva', 'arcanine', 'arceus', 'archen', 'archeops', 'arctibax', 'arctovish', 'arctozolt', 'ariados', 'armaldo', 'armarouge', 'aromatisse', 'aron', 'arrokuda', 'articuno', 'audino', 'aurorus', 'avalugg', 'axew', 'azelf', 'azumarill', 'azurill', 'bagon', 'baltoy', 'banette', 'barbaracle', 'barboach', 'barraskewda', 'basculegion-male', 'basculin-red-striped', 'bastiodon', 'baxcalibur', 'bayleef', 'beartic', 'beautifly', 'beedrill', 'beheeyem', 'beldum', 'bellibolt', 'bellossom', 'bellsprout', 'bergmite', 'bewear', 'bibarel', 'bidoof', 'binacle', 'bisharp', 'blacephalon', 'blastoise', 'blaziken', 'blipbug', 'blissey', 'blitzle', 'boldore', 'boltund', 'bombirdier', 'bonsly', 'bouffalant',

In [6]:
train_dataset = train_dataset.map(lambda x, y: (preprocess_input(x), y),
                                  num_parallel_calls = tf.data.AUTOTUNE
                                  ).prefetch(buffer_size = tf.data.AUTOTUNE)

In [7]:
model_de_base = EfficientNetB0(
    include_top = False,
    weights = 'imagenet',
    input_tensor = None,
    input_shape = (224, 224, 3),
    pooling = None,
    classes = None,
    classifier_activation = None,
)

model_de_base.trainable = False

x = model_de_base.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.2)(x)

embeddings = Dense(256, name = 'embedding')(x)
output = Dense(num_classes, activation = 'softmax')(embeddings)
model = Model(inputs = model_de_base.input, outputs = output)



In [8]:
model.compile(optimizer = Adam(1e-4),
                loss = 'categorical_crossentropy',
                metrics = ['accuracy']
                )

model.fit(train_dataset, epochs = EPOCHS_HEAD)

Epoch 1/5
[1m654/654[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m164s[0m 243ms/step - accuracy: 0.0231 - loss: 6.6085
Epoch 2/5
[1m654/654[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m166s[0m 253ms/step - accuracy: 0.1695 - loss: 5.6361
Epoch 3/5
[1m654/654[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m178s[0m 273ms/step - accuracy: 0.3377 - loss: 4.4552
Epoch 4/5
[1m654/654[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m166s[0m 253ms/step - accuracy: 0.4966 - loss: 3.3675
Epoch 5/5
[1m654/654[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6887s[0m 11s/step - accuracy: 0.6233 - loss: 2.5317


<keras.src.callbacks.history.History at 0x2252953c460>

In [9]:
output.shape

(None, 1000)

In [12]:
model_de_base.trainable = True

for layer in model_de_base.layers[:-30]:
    layer.trainable = False

model.compile(optimizer = Adam(1e-4),
                loss = 'categorical_crossentropy',
                metrics = ['accuracy']
                )

model.fit(train_dataset, epochs = EPOCHS_FINE)

Epoch 1/5
[1m654/654[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m185s[0m 275ms/step - accuracy: 0.6846 - loss: 1.7401
Epoch 2/5
[1m654/654[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m190s[0m 290ms/step - accuracy: 0.8472 - loss: 0.7872
Epoch 3/5
[1m654/654[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m197s[0m 301ms/step - accuracy: 0.9075 - loss: 0.4808
Epoch 4/5
[1m654/654[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m200s[0m 306ms/step - accuracy: 0.9429 - loss: 0.3096
Epoch 5/5
[1m654/654[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m196s[0m 299ms/step - accuracy: 0.9636 - loss: 0.2083


<keras.src.callbacks.history.History at 0x225294f00a0>

In [13]:
model.save(ENREGSITREMENT_MODELE)

