In [28]:
import tensorflow as tf
import matplotlib.pyplot as plt
from custom_data_gen import CustomDataGen
from keras import layers, models, losses
from keras.applications import MobileNetV3Small

In [29]:
# ImageNet-pretrained MobileNetV3-Small backbone. include_top=False drops the
# original ImageNet classifier so a new head can be attached later.
model = MobileNetV3Small(
    input_shape=(224, 224, 3),
    include_top=False,
    weights='imagenet',
)

In [30]:
# Freeze the pretrained backbone: with trainable=False its weights are not
# trainable, so only the layers stacked on top of it are updated by fit().
model.trainable = False

In [None]:
# Print the backbone's layer table and parameter counts; since the backbone was
# frozen above, the trainable-parameter count is expected to be 0.
model.summary()

In [32]:
# Both splits must read the same directory with the same class settings and
# batch size; only `split` differs, so the shared settings live in one place.
DATA_DIR = "/home/shared/Mammiferes_jpg"
# NOTE(review): the meaning of the per-class values (1.5) is defined inside
# CustomDataGen, which is not visible here — confirm before changing them.
CLASS_CONFIG = {"cat": 1.5, "dog": 1.5}
BATCH_SIZE = 8

# Each generator gets its own copy of the class dict (as in separate literals),
# so nothing done to it by one split can leak into the other.
train = CustomDataGen(DATA_DIR, dict(CLASS_CONFIG), batch_size=BATCH_SIZE, split="train")
test = CustomDataGen(DATA_DIR, dict(CLASS_CONFIG), batch_size=BATCH_SIZE, split="test")

In [None]:
# Report how many classes the training generator discovered.
print('Il y a {} classes.'.format(len(train.classes)))

In [None]:
# Sanity check: take the first batch from the training generator, run it
# through the backbone and show the shape of the resulting feature map.
batches = iter(train)
image_batch, label_batch = next(batches)
feature_batch = model(image_batch)
print(feature_batch.shape)

In [None]:
# Average the backbone's feature map over its spatial axes, giving one feature
# vector per image, and show the pooled shape.
global_average_layer = tf.keras.layers.GlobalAveragePooling2D(data_format=None)
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)

In [None]:
# Classification head: one output unit per class, as sparse categorical
# cross-entropy requires (a single unit makes the softmax degenerate and any
# label >= 1 out of range). The layer has no activation, so its outputs are
# raw logits and the loss must be configured with from_logits=True.
prediction_layer = tf.keras.layers.Dense(len(train.classes))
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)

In [None]:
# Assemble the classifier: backbone -> global average pooling -> dropout -> head.
# The input shape is taken from the backbone itself, so it always matches the
# input_shape the backbone was built with (224x224x3) instead of a separate
# hard-coded value.
inputs = tf.keras.Input(shape=model.input_shape[1:])
# training=False runs the backbone in inference mode (e.g. its BatchNorm layers
# use their stored statistics) even while the head is being trained.
x = model(inputs, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
# NOTE(review): this rebinds `model` from the backbone to the full classifier,
# so re-running this cell on its own would wrap the classifier a second time.
# Restart and run from the top instead.
model = tf.keras.Model(inputs, outputs)

In [36]:
# The head is a Dense layer without activation, so the model emits raw logits.
# The loss therefore has to apply the softmax itself (from_logits=True).
# Labels are expected as integer class indices (sparse categorical loss).
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

In [None]:
# Train for 20 epochs with the test split as validation data.
# Per-epoch metrics are kept in history.history.
history = model.fit(train, epochs=20, validation_data=test)

In [None]:
# Learning curves for every recorded epoch. Each line carries a label so the
# legend is populated: plt.legend() with unlabelled lines draws nothing.
# Epochs are numbered from 1 to match the training log.
train_loss = history.history["loss"]
val_loss = history.history["val_loss"]
epochs_seen = range(1, len(train_loss) + 1)

fig, ax = plt.subplots(figsize=(20, 12))
ax.plot(epochs_seen, train_loss, label="train loss")
ax.plot(epochs_seen, val_loss, label="validation loss")
ax.set(title="Training vs. validation loss", xlabel="Epoch", ylabel="Loss")
ax.legend();